In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
    

class SRMConv2d_simple(nn.Module):
    
    def __init__(self, inc=3, learnable=False):
        super(SRMConv2d_simple, self).__init__()
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3,3,5,5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        # self.hor_kernel = self._build_kernel().transpose(0,1,3,2)

    def forward(self, x):
        '''
        x: imgs (Batch, H, W, 3)
        '''
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        out = self.truc(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2:KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3:hor 2rd
        filter3 = [[0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0],
                  [0, 1, -2, 1, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # statck the filters
        filters = [[filter1],#, filter1, filter1],
                   [filter2],#, filter2, filter2],
                   [filter3]]#, filter3, filter3]]  # (3,3,5,5)
        filters = np.array(filters)
        filters = np.repeat(filters, inc, axis=1)
        filters = torch.FloatTensor(filters)    # (3,3,5,5)
        return filters

class SRMConv2d_Separate(nn.Module):
    
    def __init__(self, inc, outc, learnable=False):
        super(SRMConv2d_Separate, self).__init__()
        self.inc = inc
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3,3,5,5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        # self.hor_kernel = self._build_kernel().transpose(0,1,3,2)
        self.out_conv = nn.Sequential(
            nn.Conv2d(3*inc, outc, 1, 1, 0, 1, 1, bias=False),
            nn.BatchNorm2d(outc),
            nn.ReLU(inplace=True)
        )

        for ly in self.out_conv.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)

    def forward(self, x):
        '''
        x: imgs (Batch, H, W, 3)
        '''
        out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc)
        out = self.truc(out)
        out = self.out_conv(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2:KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # # filter3:hor 2rd
        filter3 = [[0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0],
                  [0, 1, -2, 1, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # statck the filters
        filters = [[filter1],#, filter1, filter1],
                   [filter2],#, filter2, filter2],
                   [filter3]]#, filter3, filter3]]  # (3,3,5,5)
        filters = np.array(filters)
        # filters = np.repeat(filters, inc, axis=1)
        filters = np.repeat(filters, inc, axis=0)
        filters = torch.FloatTensor(filters)    # (3,3,5,5)
        # print(filters.size())
        return filters
In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
Channel Attention and Spaitial Attention from    
Woo, S., Park, J., Lee, J.Y., & Kweon, I. CBAM: Convolutional Block Attention Module. ECCV2018.
"""


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgout, maxout], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)


"""
The following modules are modified based on https://github.com/heykeetae/Self-Attention-GAN
"""


class Self_Attn(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim, out_dim=None, add=False, ratio=8):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        self.add = add
        if out_dim is None:
            out_dim = in_dim
        self.out_dim = out_dim
        # self.activation = activation

        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(
            m_batchsize, -1, width*height).permute(0, 2, 1)  # B X C X(N)
        proj_key = self.key_conv(x).view(
            m_batchsize, -1, width*height)  # B X C x (*W*H)
        energy = torch.bmm(proj_query, proj_key)  # transpose check
        attention = self.softmax(energy)  # BX (N) X (N)
        proj_value = self.value_conv(x).view(
            m_batchsize, -1, width*height)  # B X C X N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, self.out_dim, width, height)

        if self.add:
            out = self.gamma*out + x
        else:
            out = self.gamma*out
        return out  # , attention


class CrossModalAttention(nn.Module):
    """ CMA attention Layer"""

    def __init__(self, in_dim, activation=None, ratio=8, cross_value=True):
        super(CrossModalAttention, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        self.cross_value = cross_value

        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x, y):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        B, C, H, W = x.size()

        proj_query = self.query_conv(x).view(
            B, -1, H*W).permute(0, 2, 1)  # B , HW, C
        proj_key = self.key_conv(y).view(
            B, -1, H*W)  # B X C x (*W*H)
        energy = torch.bmm(proj_query, proj_key)  # B, HW, HW
        attention = self.softmax(energy)  # BX (N) X (N)
        if self.cross_value:
            proj_value = self.value_conv(y).view(
                B, -1, H*W)  # B , C , HW
        else:
            proj_value = self.value_conv(x).view(
                B, -1, H*W)  # B , C , HW

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)

        out = self.gamma*out + x

        if self.activation is not None:
            out = self.activation(out)

        return out  # , attention


class DualCrossModalAttention(nn.Module):
    """ Dual CMA attention Layer"""

    def __init__(self, in_dim, activation=None, size=16, ratio=8, ret_att=False):
        super(DualCrossModalAttention, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        self.ret_att = ret_att

        # query conv
        self.key_conv1 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv2 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim//ratio, kernel_size=1)
        self.key_conv_share = nn.Conv2d(
            in_channels=in_dim//ratio, out_channels=in_dim//ratio, kernel_size=1)

        self.linear1 = nn.Linear(size*size, size*size)
        self.linear2 = nn.Linear(size*size, size*size)

        # separated value conv
        self.value_conv1 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma1 = nn.Parameter(torch.zeros(1))

        self.value_conv2 = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma2 = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data, gain=0.02)

    def forward(self, x, y):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        B, C, H, W = x.size()

        def _get_att(a, b):
            proj_key1 = self.key_conv_share(self.key_conv1(a)).view(
                B, -1, H*W).permute(0, 2, 1)  # B, HW, C
            proj_key2 = self.key_conv_share(self.key_conv2(b)).view(
                B, -1, H*W)  # B X C x (*W*H)
            energy = torch.bmm(proj_key1, proj_key2)  # B, HW, HW

            attention1 = self.softmax(self.linear1(energy))
            attention2 = self.softmax(self.linear2(
                energy.permute(0, 2, 1)))  # BX (N) X (N)

            return attention1, attention2

        att_y_on_x, att_x_on_y = _get_att(x, y)
        proj_value_y_on_x = self.value_conv2(y).view(
            B, -1, H*W)  # B, C, HW
        out_y_on_x = torch.bmm(proj_value_y_on_x, att_y_on_x.permute(0, 2, 1))
        out_y_on_x = out_y_on_x.view(B, C, H, W)
        out_x = self.gamma1*out_y_on_x + x

        proj_value_x_on_y = self.value_conv1(x).view(
            B, -1, H*W)  # B , C , HW
        out_x_on_y = torch.bmm(proj_value_x_on_y, att_x_on_y.permute(0, 2, 1))
        out_x_on_y = out_x_on_y.view(B, C, H, W)
        out_y = self.gamma2*out_x_on_y + y

        if self.ret_att:
            return out_x, out_y, att_y_on_x, att_x_on_y

        return out_x, out_y  # , attention


if __name__ == "__main__":
    x = torch.rand(10, 768, 16, 16)
    y = torch.rand(10, 768, 16, 16)
    dcma = DualCrossModalAttention(768, ret_att=True)
    out_x, out_y, att_y_on_x, att_x_on_y = dcma(x, y)
    print(out_y.size())
    print(att_x_on_y.size())
torch.Size([10, 768, 16, 16])
torch.Size([10, 256, 256])
In [3]:
"""
 Copyright (c) 2018 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class AngleSimpleLinear(nn.Module):
    """Computes cos of angles between input vectors and weights vectors"""
    def __init__(self, in_features, out_features):
        super(AngleSimpleLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

    def forward(self, x):
        cos_theta = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))
        return cos_theta.clamp(-1, 1)


def focal_loss(input_values, gamma):
    """Computes the focal loss"""
    p = torch.exp(-input_values)
    loss = (1 - p) ** gamma * input_values
    return loss.mean()


class AMSoftmaxLoss(nn.Module):
    """Computes the AM-Softmax loss with cos or arc margin"""
    margin_types = ['cos', 'arc']

    def __init__(self, margin_type='cos', gamma=0., m=0.5, s=30, t=1.):
        super(AMSoftmaxLoss, self).__init__()
        assert margin_type in AMSoftmaxLoss.margin_types
        self.margin_type = margin_type
        assert gamma >= 0
        self.gamma = gamma
        assert m > 0
        self.m = m
        assert s > 0
        self.s = s
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        assert t >= 1
        self.t = t

    def forward(self, cos_theta, target):
        if self.margin_type == 'cos':
            phi_theta = cos_theta - self.m
        else:
            sine = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
            phi_theta = cos_theta * self.cos_m - sine * self.sin_m #cos(theta+m)
            phi_theta = torch.where(cos_theta > self.th, phi_theta, cos_theta - self.sin_m * self.m)

        index = torch.zeros_like(cos_theta, dtype=torch.uint8)
        index.scatter_(1, target.data.view(-1, 1), 1)
        output = torch.where(index, phi_theta, cos_theta)

        if self.gamma == 0 and self.t == 1.:
            return F.cross_entropy(self.s*output, target)

        if self.t > 1:
            h_theta = self.t - 1 + self.t*cos_theta
            support_vecs_mask = (1 - index) * \
                torch.lt(torch.masked_select(phi_theta, index).view(-1, 1).repeat(1, h_theta.shape[1]) - cos_theta, 0)
            output = torch.where(support_vecs_mask, h_theta, output)
            return F.cross_entropy(self.s*output, target)

        return focal_loss(F.cross_entropy(self.s*output, target, reduction='none'), self.gamma)
In [5]:
"""
Code from https://github.com/ondyari/FaceForensics
Author: Andreas Rössler
"""
import os
import argparse


import torch
# import pretrainedmodels
import torch.nn as nn
import torch.nn.functional as F
# from lib.nets.xception import xception
import math
import torchvision

# import math
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init

pretrained_settings = {
    'xception': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000,
            'scale': 0.8975  # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
        }
    }
}

PRETAINED_WEIGHT_PATH = '/kaggle/input/xceptionb5690688pth/xception-b5690688.pth'

class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size,
                               stride, padding, dilation, groups=in_channels, bias=bias)
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):
    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters,
                                  1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x += skip
        return x


def add_gaussian_noise(ins, mean=0, stddev=0.2):
    noise = ins.data.new(ins.size()).normal_(mean, stddev)
    return ins + noise


class Xception(nn.Module):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1610.02357.pdf
    """

    def __init__(self, num_classes=1000, inc=3):
        """ Constructor
        Args:
            num_classes: number of classes
        """
        super(Xception, self).__init__()
        self.num_classes = num_classes

        # Entry flow
        self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # do relu here

        self.block1 = Block(
            64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(
            128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(
            256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # middle flow
        self.block4 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block8 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Exit flow
        self.block12 = Block(
            728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        # do relu here
        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.fc = nn.Linear(2048, num_classes)

        # #------- init weights --------
        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        #         m.weight.data.normal_(0, math.sqrt(2. / n))
        #     elif isinstance(m, nn.BatchNorm2d):
        #         m.weight.data.fill_(1)
        #         m.bias.data.zero_()
        # #-----------------------------
    def fea_part1_0(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return x

    def fea_part1_1(self, x):

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

    def fea_part1(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

    def fea_part2(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        return x

    def fea_part3(self, x):
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)

        return x

    def fea_part4(self, x):
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)

        return x

    def fea_part5(self, x):
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)

        return x

    def features(self, input):
        x = self.fea_part1(input)

        x = self.fea_part2(x)
        x = self.fea_part3(x)
        x = self.fea_part4(x)

        x = self.fea_part5(x)
        return x

    def classifier(self, features):
        x = self.relu(features)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        out = self.last_linear(x)
        return out, x

    def forward(self, input):
        x = self.features(input)
        out, x = self.classifier(x)
        return out, x


def xception(num_classes=1000, pretrained='imagenet', inc=3):
    model = Xception(num_classes=num_classes, inc=inc)
    if pretrained:
        settings = pretrained_settings['xception'][pretrained]
        assert num_classes == settings['num_classes'], \
            "num_classes should be {}, but is {}".format(
                settings['num_classes'], num_classes)

        model = Xception(num_classes=num_classes)
        model.load_state_dict(model_zoo.load_url(settings['url']))

        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']
        model.mean = settings['mean']
        model.std = settings['std']

    # TODO: ugly
    model.last_linear = model.fc
    del model.fc
    return model


class TransferModel(nn.Module):
    """
    Simple transfer learning model that takes an imagenet pretrained model with
    a fc layer as base model and retrains a new fc layer for num_out_classes
    """

    def __init__(self, modelchoice, num_out_classes=2, dropout=0.0,
                 weight_norm=False, return_fea=False, inc=3):
        super(TransferModel, self).__init__()
        self.modelchoice = modelchoice
        self.return_fea = return_fea

        if modelchoice == 'xception':

            def return_pytorch04_xception(pretrained=True):
                # Raises warning "src not broadcastable to dst" but thats fine
                model = xception(pretrained=False)
                if pretrained:
                    # Load model in torch 0.4+
                    model.fc = model.last_linear
                    del model.last_linear
                    state_dict = torch.load(
                        PRETAINED_WEIGHT_PATH)
                    for name, weights in state_dict.items():
                        if 'pointwise' in name:
                            state_dict[name] = weights.unsqueeze(
                                -1).unsqueeze(-1)
                    model.load_state_dict(state_dict)
                    model.last_linear = model.fc
                    del model.fc
                return model

            self.model = return_pytorch04_xception()
            # Replace fc
            num_ftrs = self.model.last_linear.in_features
            if not dropout:
                if weight_norm:
                    print('Using Weight_Norm')
                    self.model.last_linear = nn.utils.weight_norm(
                        nn.Linear(num_ftrs, num_out_classes), name='weight')
                self.model.last_linear = nn.Linear(num_ftrs, num_out_classes)
            else:
                print('Using dropout', dropout)
                if weight_norm:
                    print('Using Weight_Norm')
                    self.model.last_linear = nn.Sequential(
                        nn.Dropout(p=dropout),
                        nn.utils.weight_norm(
                            nn.Linear(num_ftrs, num_out_classes), name='weight')
                    )

                self.model.last_linear = nn.Sequential(
                    nn.Dropout(p=dropout),
                    nn.Linear(num_ftrs, num_out_classes)
                )

            if inc != 3:
                self.model.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
                nn.init.xavier_normal(self.model.conv1.weight.data, gain=0.02)

        elif modelchoice == 'resnet50' or modelchoice == 'resnet18':
            if modelchoice == 'resnet50':
                self.model = torchvision.models.resnet50(pretrained=True)
            if modelchoice == 'resnet18':
                self.model = torchvision.models.resnet18(pretrained=True)
            # Replace fc
            num_ftrs = self.model.fc.in_features
            if not dropout:
                self.model.fc = nn.Linear(num_ftrs, num_out_classes)
            else:
                self.model.fc = nn.Sequential(
                    nn.Dropout(p=dropout),
                    nn.Linear(num_ftrs, num_out_classes)
                )
        else:
            raise Exception('Choose valid model, e.g. resnet50')

    def set_trainable_up_to(self, boolean=False, layername="Conv2d_4a_3x3"):
        """
        Freezes all layers below a specific layer and sets the following layers
        to true if boolean else only the fully connected final layer
        :param boolean:
        :param layername: depends on lib, for inception e.g. Conv2d_4a_3x3
        :return:
        """
        # Stage-1: freeze all the layers
        if layername is None:
            for i, param in self.model.named_parameters():
                param.requires_grad = True
                return
        else:
            for i, param in self.model.named_parameters():
                param.requires_grad = False
        if boolean:
            # Make all layers following the layername layer trainable
            ct = []
            found = False
            for name, child in self.model.named_children():
                if layername in ct:
                    found = True
                    for params in child.parameters():
                        params.requires_grad = True
                ct.append(name)
            if not found:
                raise NotImplementedError('Layer not found, cant finetune!'.format(
                    layername))
        else:
            if self.modelchoice == 'xception':
                # Make fc trainable
                for param in self.model.last_linear.parameters():
                    param.requires_grad = True

            else:
                # Make fc trainable
                for param in self.model.fc.parameters():
                    param.requires_grad = True

    def forward(self, x):
        out, x = self.model(x)
        if self.return_fea:
            return out, x
        else:
            return out

    def features(self, x):
        x = self.model.features(x)
        return x

    def classifier(self, x):
        out, x = self.model.classifier(x)
        return out, x


def model_selection(modelname, num_out_classes,
                    dropout=None):
    """
    :param modelname:
    :return: model, image size, pretraining<yes/no>, input_list
    """
    if modelname == 'xception':
        return TransferModel(modelchoice='xception',
                             num_out_classes=num_out_classes), 299, \
            True, ['image'], None
    elif modelname == 'resnet18':
        return TransferModel(modelchoice='resnet18', dropout=dropout,
                             num_out_classes=num_out_classes), \
            224, True, ['image'], None
    else:
        raise NotImplementedError(modelname)


if __name__ == '__main__':
    model = TransferModel('xception', dropout=0.5)
    print(model)
    # model = model.cuda()
    # from torchsummary import summary
    # input_s = (3, image_size, image_size)
    # print(summary(model, input_s))
    dummy = torch.rand(10, 3, 256, 256)
    out = model(dummy)
    print(out.size())
    x = model.features(dummy)
    out, x = model.classifier(x)
    print(out.size())
    print(x.size())
Using dropout 0.5
TransferModel(
  (model): Xception(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (block1): Block(
      (skip): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (skipbn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): SeparableConv2d(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): SeparableConv2d(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
          (pointwise): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      )
    )
    (block2): Block(
      (skip): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (skipbn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): ReLU()
        (1): SeparableConv2d(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
          (pointwise): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): SeparableConv2d(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
          (pointwise): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      )
    )
    (block3): Block(
      (skip): Conv2d(256, 728, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (skipbn): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): ReLU()
        (1): SeparableConv2d(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
          (pointwise): Conv2d(256, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      )
    )
    (block4): Block(
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): ReLU()
        (1): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
        (7): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (block5): Block(
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): ReLU()
        (1): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
        (7): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (block6): Block(
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): ReLU()
        (1): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
        (7): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (block7): Block(
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): ReLU()
        (1): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
        (7): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (block8): Block(
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): ReLU()
        (1): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
        (7): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (block9): Block(
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): ReLU()
        (1): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
        (7): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (block10): Block(
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): ReLU()
        (1): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
        (7): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (block11): Block(
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): ReLU()
        (1): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
        (7): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (block12): Block(
      (skip): Conv2d(728, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (skipbn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): ReLU()
        (1): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): SeparableConv2d(
          (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
          (pointwise): Conv2d(728, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      )
    )
    (conv3): SeparableConv2d(
      (conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)
      (pointwise): Conv2d(1024, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (bn3): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv4): SeparableConv2d(
      (conv1): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)
      (pointwise): Conv2d(1536, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (bn4): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (last_linear): Sequential(
      (0): Dropout(p=0.5, inplace=False)
      (1): Linear(in_features=2048, out_features=2, bias=True)
    )
  )
)
torch.Size([10, 2])
torch.Size([10, 2])
torch.Size([10, 2048])
In [87]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# from components.attention import ChannelAttention, SpatialAttention, DualCrossModalAttention
# from components.srm_conv import SRMConv2d_simple, SRMConv2d_Separate
# from networks.xception import TransferModel


class SRMPixelAttention(nn.Module):
    def __init__(self, in_channels):
        super(SRMPixelAttention, self).__init__()
        self.srm = SRMConv2d_simple()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, 2, 0, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        
        self.pa = SpatialAttention()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=1)
                if not m.bias is None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x_srm = self.srm(x)
        fea = self.conv(x_srm)        
        att_map = self.pa(fea)
        
        return att_map


class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan=2048*2, out_chan=2048, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_chan),
            nn.ReLU()
        )
        self.ca = ChannelAttention(out_chan, ratio=16)
        self.init_weight()

    def forward(self, x, y):
        fuse_fea = self.convblk(torch.cat((x, y), dim=1))
        fuse_fea = fuse_fea + fuse_fea * self.ca(fuse_fea)
        return fuse_fea

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None:
                    nn.init.constant_(ly.bias, 0)


class Two_Stream_Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.xception_rgb = TransferModel(
            'xception', dropout=0.5, inc=3, return_fea=True)
        self.xception_srm = TransferModel(
            'xception', dropout=0.5, inc=3, return_fea=True)

        self.srm_conv0 = SRMConv2d_simple(inc=3)
        self.srm_conv1 = SRMConv2d_Separate(32, 32)
        self.srm_conv2 = SRMConv2d_Separate(64, 64)
        self.relu = nn.ReLU(inplace=True)

        self.att_map = None
        self.srm_sa = SRMPixelAttention(3)
        self.srm_sa_post = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.dual_cma0 = DualCrossModalAttention(in_dim=728, ret_att=False)
        self.dual_cma1 = DualCrossModalAttention(in_dim=728, ret_att=False)

        self.fusion = FeatureFusionModule()

        self.att_dic = {}

    def features(self, x):
        srm = self.srm_conv0(x)

        x = self.xception_rgb.model.fea_part1_0(x)
        y = self.xception_srm.model.fea_part1_0(srm) \
            + self.srm_conv1(x)
        y = self.relu(y)

        x = self.xception_rgb.model.fea_part1_1(x)
        y = self.xception_srm.model.fea_part1_1(y) \
            + self.srm_conv2(x)
        y = self.relu(y)

        # srm guided spatial attention
        self.att_map = self.srm_sa(srm)
        x = x * self.att_map + x
        x = self.srm_sa_post(x)

        x = self.xception_rgb.model.fea_part2(x)
        y = self.xception_srm.model.fea_part2(y)

        x, y = self.dual_cma0(x, y)


        x = self.xception_rgb.model.fea_part3(x)        
        y = self.xception_srm.model.fea_part3(y)
 

        x, y = self.dual_cma1(x, y)

        x = self.xception_rgb.model.fea_part4(x)
        y = self.xception_srm.model.fea_part4(y)

        x = self.xception_rgb.model.fea_part5(x)
        y = self.xception_srm.model.fea_part5(y)

        fea = self.fusion(x, y)
                

        return fea

    def classifier(self, fea):
        out, fea = self.xception_rgb.classifier(fea)
        return out, fea

    def forward(self, x):
        '''
        x: original rgb
        '''
        out, fea = self.classifier(self.features(x))

#         return out, fea, self.att_map
        return out
    
if __name__ == '__main__':
    model = Two_Stream_Net()
    dummy = torch.rand((1,3,256,256))
    out = model(dummy)
    print(model)
    
Using dropout 0.5
Using dropout 0.5
Two_Stream_Net(
  (xception_rgb): TransferModel(
    (model): Xception(
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (block1): Block(
        (skip): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (skipbn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): SeparableConv2d(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
            (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): SeparableConv2d(
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
            (pointwise): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        )
      )
      (block2): Block(
        (skip): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (skipbn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
            (pointwise): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
            (pointwise): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        )
      )
      (block3): Block(
        (skip): Conv2d(256, 728, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (skipbn): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
            (pointwise): Conv2d(256, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        )
      )
      (block4): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block5): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block6): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block7): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block8): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block9): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block10): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block11): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block12): Block(
        (skip): Conv2d(728, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (skipbn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        )
      )
      (conv3): SeparableConv2d(
        (conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)
        (pointwise): Conv2d(1024, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn3): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv4): SeparableConv2d(
        (conv1): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)
        (pointwise): Conv2d(1536, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn4): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (last_linear): Sequential(
        (0): Dropout(p=0.5, inplace=False)
        (1): Linear(in_features=2048, out_features=2, bias=True)
      )
    )
  )
  (xception_srm): TransferModel(
    (model): Xception(
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (block1): Block(
        (skip): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (skipbn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): SeparableConv2d(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
            (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): SeparableConv2d(
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
            (pointwise): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        )
      )
      (block2): Block(
        (skip): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (skipbn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
            (pointwise): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
            (pointwise): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        )
      )
      (block3): Block(
        (skip): Conv2d(256, 728, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (skipbn): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
            (pointwise): Conv2d(256, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        )
      )
      (block4): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block5): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block6): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block7): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block8): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block9): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block10): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block11): Block(
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (8): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (block12): Block(
        (skip): Conv2d(728, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (skipbn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (rep): Sequential(
          (0): ReLU()
          (1): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (2): BatchNorm2d(728, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): SeparableConv2d(
            (conv1): Conv2d(728, 728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=728, bias=False)
            (pointwise): Conv2d(728, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        )
      )
      (conv3): SeparableConv2d(
        (conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)
        (pointwise): Conv2d(1024, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn3): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv4): SeparableConv2d(
        (conv1): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)
        (pointwise): Conv2d(1536, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn4): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (last_linear): Sequential(
        (0): Dropout(p=0.5, inplace=False)
        (1): Linear(in_features=2048, out_features=2, bias=True)
      )
    )
  )
  (srm_conv0): SRMConv2d_simple(
    (truc): Hardtanh(min_val=-3, max_val=3)
  )
  (srm_conv1): SRMConv2d_Separate(
    (truc): Hardtanh(min_val=-3, max_val=3)
    (out_conv): Sequential(
      (0): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (srm_conv2): SRMConv2d_Separate(
    (truc): Hardtanh(min_val=-3, max_val=3)
    (out_conv): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (relu): ReLU(inplace=True)
  (srm_sa): SRMPixelAttention(
    (srm): SRMConv2d_simple(
      (truc): Hardtanh(min_val=-3, max_val=3)
    )
    (conv): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (pa): SpatialAttention(
      (conv): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
      (sigmoid): Sigmoid()
    )
  )
  (srm_sa_post): Sequential(
    (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ReLU(inplace=True)
  )
  (dual_cma0): DualCrossModalAttention(
    (key_conv1): Conv2d(728, 91, kernel_size=(1, 1), stride=(1, 1))
    (key_conv2): Conv2d(728, 91, kernel_size=(1, 1), stride=(1, 1))
    (key_conv_share): Conv2d(91, 91, kernel_size=(1, 1), stride=(1, 1))
    (linear1): Linear(in_features=256, out_features=256, bias=True)
    (linear2): Linear(in_features=256, out_features=256, bias=True)
    (value_conv1): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1))
    (value_conv2): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1))
    (softmax): Softmax(dim=-1)
  )
  (dual_cma1): DualCrossModalAttention(
    (key_conv1): Conv2d(728, 91, kernel_size=(1, 1), stride=(1, 1))
    (key_conv2): Conv2d(728, 91, kernel_size=(1, 1), stride=(1, 1))
    (key_conv_share): Conv2d(91, 91, kernel_size=(1, 1), stride=(1, 1))
    (linear1): Linear(in_features=256, out_features=256, bias=True)
    (linear2): Linear(in_features=256, out_features=256, bias=True)
    (value_conv1): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1))
    (value_conv2): Conv2d(728, 728, kernel_size=(1, 1), stride=(1, 1))
    (softmax): Softmax(dim=-1)
  )
  (fusion): FeatureFusionModule(
    (convblk): Sequential(
      (0): Conv2d(4096, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (ca): ChannelAttention(
      (avg_pool): AdaptiveAvgPool2d(output_size=1)
      (max_pool): AdaptiveMaxPool2d(output_size=1)
      (sharedMLP): Sequential(
        (0): Conv2d(2048, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): ReLU()
        (2): Conv2d(128, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (sigmoid): Sigmoid()
    )
  )
)
In [49]:
import numpy as np # linear algebra
import pandas as pd
from glob import glob
# from retinaface import RetinaFace
import torch
from torch import optim
import torchvision
import timm
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import seaborn as sns
from PIL import Image
import random
import os
from torchvision.transforms import v2
from torch.utils.data import Dataset , DataLoader
import cv2
import matplotlib.pyplot as plt
import albumentations as A
from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
      RandomBrightnessContrast, Rotate, ShiftScaleRotate,  Transpose
    )
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import KFold
import torch.nn as nn
from contextlib import contextmanager
from torch.optim import Adam, SGD
from functools import partial
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
import time
from sklearn.metrics import roc_auc_score
import math
from catalyst.data import BalanceClassSampler
In [60]:
txt_to_csv = False
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIR_PATH = "/kaggle/input/deepfake/phase1"
TRAIN_DIR = "/kaggle/input/deepfake/phase1/trainset"
TEST_DIR = "/kaggle/input/deepfake/phase1/valset"
OUTPUT_DIR = "/kaggle/working/"
class CFG : 
    seed = 42
    n_fold = 5
    target_col = 'target'
    train=True
    inference=False
    pseudo_labeling = True
    num_classes = 2 #binary class
    trn_fold=[0, 1]
    debug=False
    apex=False
    print_freq=20 #every how many batch the scores get showed
    num_workers=4
#     model_name="eva02_large_patch14_448.mim_m38m_ft_in22k_in1k"
#     model_name=  "efficientnet_b3"
    size=256
    scheduler='CosineAnnealingWarmRestarts' 
    epochs=2
    lr=1e-4
    min_lr=1e-6
    T_0=10 # CosineAnnealingWarmRestarts
    batch_size=20
    weight_decay=1e-6
    gradient_accumulation_steps=1
    max_grad_norm=1000
In [51]:
train = pd.read_csv(f"{DIR_PATH +'/trainset_label.txt'}")
test = pd.read_csv(f"{DIR_PATH +'/valset_label.txt'}")
In [52]:
if CFG.pseudo_labeling : 
    ps = pd.read_csv('/kaggle/input/pseudolabling/b4_nTTA.csv')
    ps.rename(columns = {"label" : "target"} , inplace = True)
    to_add = ps[(ps['target']>0.9) | (ps['target']<0.1)]
#     print(to_add.shape)
    to_add["target"] = [1 if i>0.9 else 0 for i in to_add['target']]
    print(to_add["target"].value_counts())
    shape_before = train.shape
    train = pd.concat([train , to_add] , axis=0)
    shape_after = train.shape
    print(f"The shape of the train set have moved from {shape_before} => {shape_after}")
    train.reset_index(drop = True , inplace =True , )
target
1    87148
0    57023
Name: count, dtype: int64
The shape of the train set have moved from (524429, 2) => (668600, 2)
In [53]:
from sklearn.metrics import log_loss

def get_score(y_true, y_pred):
    num_classes = 2 
    total_log_loss = 0.0
    y_true = np.array([[0, 1] if i == 1 else [1, 0] for i in y_true])
#     print(y_true)
#     print(y_pred)
    for class_idx in range(num_classes):
        
        class_true = y_true[:,class_idx]  
        class_pred = y_pred[:, class_idx] 

        class_log_loss = log_loss(class_true, class_pred)
        total_log_loss += class_log_loss
    return total_log_loss
#     mean_log_loss = total_log_loss / num_classes
#     return mean_log_loss

# def get_score(y_true, y_pred):
#     # Ensure y_true and y_pred are 1D arrays
#     y_true = y_true.flatten()
#     y_pred = y_pred.flatten()

#     # Calculate the log loss directly
#     total_log_loss = log_loss(y_true, y_pred)



@contextmanager
def timer(name):
    t0 = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    LOGGER.info(f'[{name}] done in {time.time() - t0:.0f} s.')


def init_logger(log_file=OUTPUT_DIR+'train.log'):
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = init_logger()



def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(seed=CFG.seed)
In [54]:
if CFG.debug:
    CFG.epochs = 1
    train = train.sample(n=10000, random_state=CFG.seed).reset_index(drop=True)
    test = test.sample(n=1000, random_state=CFG.seed).reset_index(drop=True)
In [55]:
files = glob(DIR_PATH+"/valset/*")
In [56]:
def len_txt(txt_file_path):
    with open(txt_file_path) as f:
        line_count = 0
        for line in f:
            line_count += 1
    return line_count
In [57]:
print(f"The train file contains {len_txt(DIR_PATH +'/trainset_label.txt')} elements")
print(f"The test file contains {len_txt(DIR_PATH +'/valset_label.txt')} elements")
The train file contains 524430 elements
The test file contains 147364 elements
In [58]:
# tkhalwidh
if txt_to_csv : 
    
    with open(DIR_PATH+"/trainset_label.txt") as f : 
        counter = 0
        for line in tqdm(f , desc = "Collecting train set") : 

            if counter >= 1 : 
                l = line.strip().split(",")
                new_row = {"img_name": l[0] , "target": l[1]}
                train.loc[len(train)] = new_row
            counter +=1
        
    with open(DIR_PATH+"/valset_label.txt") as f : 
        counter = 0
        for line in tqdm(f , desc = "Collecting test set") : 

            if counter >= 1 : 
                l = line.strip().split(",")
                new_row = {"img_name": l[0] , "target": l[1]}
                test.loc[len(test)] = new_row
            counter +=1
In [59]:
sns.countplot(data = train , x = train["target"])
Out[59]:
<Axes: xlabel='target', ylabel='count'>
In [73]:
class TrainDataset(Dataset) : 
    def __init__(self , df , transform = None) : 
        self.df = df 
        self.transform = transform
        self.file_names = df["img_name"].values
        self.labels = df["target"].values

        
        
    def __len__(self) : 
        return len(self.df)
    
    def __getitem__(self, idx):
        file_name = self.file_names[idx]

        # Check if the file is in the TRAIN_DIR or TEST_DIR
        file_path_train = f'{TRAIN_DIR}/{file_name}'
        file_path_test = f'{TEST_DIR}/{file_name}'

        if os.path.exists(file_path_train):
            file_path = file_path_train
        elif os.path.exists(file_path_test):
            file_path = file_path_test
        else:
            raise FileNotFoundError(f'File {file_name} not found in either TRAIN_DIR or TEST_DIR')

        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        
        if self.transform:
            
            augmented = self.transform(image=image)
            image = augmented['image']

        label = torch.tensor(self.labels[idx]).long()
        return image, label
    
    def get_labels(self):
        return list(self.labels)
    
class TestDataset(Dataset) : 
    def __init__(self , df , transform = None) : 
        self.df = df 
        self.transform = transform
        self.file_names = df["img_name"].values
        
    def __len__(self) : 
        return len(self.df)
    
    def __getitem__(self , idx) : 
        
        file_name = self.file_names[idx]
        file_path = f'{TEST_DIR}/{file_name}' 
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image , cv2.COLOR_BGR2RGB)
        if self.transform : 
            augmented = self.transform(image=image)
            image = augmented['image']
            
        return image 
In [74]:
train_dataset = TrainDataset(train)
In [75]:
fig, axes = plt.subplots(2, 4, figsize=(10, 7))

for i in range(2):
    for j in range(4):
        index = i * 3 + j
        if index < len(train_dataset):
            image, label = train_dataset[index]
            axes[i, j].imshow(image)
            if label.numpy() == 1:
                axes[i, j].set_title("Fake", color="r")
            else:
                axes[i, j].set_title("Real", color="g")
            axes[i, j].axis('off')

plt.tight_layout()
plt.show()
In [76]:
from albumentations import Compose, RandomBrightnessContrast, RandomCrop, \
    HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, ISONoise, MultiplicativeNoise, CoarseDropout, MedianBlur, Blur, GlassBlur, MotionBlur, \
    ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur, ToSepia, RandomShadow, RandomGamma, Rotate, Resize
from albumentations import RandomBrightnessContrast 
from PIL import Image
# from transforms.albu import IsotropicResize, FFT, SR, DCT, CustomRandomCrop
import cv2
import numpy as np
import os 
import imageio

import random

import cv2
import numpy as np
import torch
from albumentations import DualTransform, ImageOnlyTransform
from albumentations.augmentations.crops.transforms import Crop


from skimage.color import rgb2hsv, rgb2gray, rgb2yuv
from skimage import color, exposure, transform
from skimage.exposure import equalize_hist
from albumentations import RandomCrop
from scipy.fftpack import dct, idct

def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
    h, w = img.shape[:2]

    if max(w, h) == size:
        return img
    if w > h:
        scale = size / w
        h = h * scale
        w = size
    else:
        scale = size / h
        w = w * scale
        h = size
    interpolation = interpolation_up if scale > 1 else interpolation_down

    img = img.astype('uint8')
    resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
    return resized


class IsotropicResize(DualTransform):
    def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
                 always_apply=False, p=1):
        super(IsotropicResize, self).__init__(always_apply, p)
        self.max_side = max_side
        self.interpolation_down = interpolation_down
        self.interpolation_up = interpolation_up

    def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
        return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
                                          interpolation_up=interpolation_up)

    def apply_to_mask(self, img, **params):
        return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)

    def get_transform_init_args_names(self):
        return ("max_side", "interpolation_down", "interpolation_up")


class Resize4xAndBack(ImageOnlyTransform):
    def __init__(self, always_apply=False, p=0.5):
        super(Resize4xAndBack, self).__init__(always_apply, p)

    def apply(self, img, **params):
        h, w = img.shape[:2]
        scale = random.choice([2, 4])
        img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_AREA)
        img = cv2.resize(img, (w, h),
                         interpolation=random.choice([cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST]))
        return img


class RandomSizedCropNonEmptyMaskIfExists(DualTransform):

    def __init__(self, min_max_height, w2h_ratio=[0.7, 1.3], always_apply=False, p=0.5):
        super(RandomSizedCropNonEmptyMaskIfExists, self).__init__(always_apply, p)

        self.min_max_height = min_max_height
        self.w2h_ratio = w2h_ratio

    def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
        cropped = crop(img, x_min, y_min, x_max, y_max)
        return cropped

    @property
    def targets_as_params(self):
        return ["mask"]

    def get_params_dependent_on_targets(self, params):
        mask = params["mask"]
        mask_height, mask_width = mask.shape[:2]
        crop_height = int(mask_height * random.uniform(self.min_max_height[0], self.min_max_height[1]))
        w2h_ratio = random.uniform(*self.w2h_ratio)
        crop_width = min(int(crop_height * w2h_ratio), mask_width - 1)
        if mask.sum() == 0:
            x_min = random.randint(0, mask_width - crop_width + 1)
            y_min = random.randint(0, mask_height - crop_height + 1)
        else:
            mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
            non_zero_yx = np.argwhere(mask)
            y, x = random.choice(non_zero_yx)
            x_min = x - random.randint(0, crop_width - 1)
            y_min = y - random.randint(0, crop_height - 1)
            x_min = np.clip(x_min, 0, mask_width - crop_width)
            y_min = np.clip(y_min, 0, mask_height - crop_height)

        x_max = x_min + crop_height
        y_max = y_min + crop_width
        y_max = min(mask_height, y_max)
        x_max = min(mask_width, x_max)
        return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}

    def get_transform_init_args_names(self):
        return "min_max_height", "height", "width", "w2h_ratio"

class CustomRandomCrop(DualTransform):
    def __init__(self, size, p=0.5) -> None:
        super(CustomRandomCrop, self).__init__(p=p)
        self.size = size
        self.prob = p

    def apply(self, img, copy=True, **params):
        if img.shape[0] < self.size or img.shape[1] < self.size:
            transform = IsotropicResize(max_side=self.size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR)
        else:
            transform = RandomCrop(self.size, self.size)
        return np.asarray(transform(image=img)["image"])

class FFT(DualTransform):
    def __init__(self, mode, p=0.5) -> None:
        super(FFT, self).__init__(p=p)
        self.prob = p
        self.mode = mode

    def apply(self, img, copy=True, **params):
        dark_image_grey_fourier = np.fft.fftshift(np.fft.fft2(rgb2gray(img)))
        mask = np.log(abs(dark_image_grey_fourier)).astype(np.uint8)
        mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
        if self.mode == 0:
            return np.asarray(cv2.bitwise_and(img, img, mask=mask))
        else:
            mask = np.asarray(mask)
            image =  cv2.merge((mask, mask, mask))
            return image

class SR(DualTransform):
    def __init__(self, model_sr, p=0.5) -> None:
        super(SR, self).__init__(p=p)
        self.prob = p
        self.model_sr = model_sr

    def apply(self, img, copy=True, **params):
        img = cv2.resize(img, (int(img.shape[1]/2), int(img.shape[0]/2)), interpolation = cv2.INTER_AREA)
        img = np.transpose(img, (2, 0, 1))
        img = torch.tensor(img, dtype=torch.float).unsqueeze(0).to(2)
        sr_img = self.model_sr(img)
        return sr_img.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()


class DCT(DualTransform):
    def __init__(self, mode, p=0.5) -> None:
        super(DCT, self).__init__(p=p)
        self.prob = p
        self.mode = mode

    def rgb2gray(self, rgb):
        return cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)

    def apply(self, img, copy=True, **params):
        gray_img = self.rgb2gray(img)
        dct_coefficients = cv2.dct(cv2.dct(np.float32(gray_img), flags=cv2.DCT_ROWS), flags=cv2.DCT_ROWS)
        epsilon = 1
        mask = np.log(np.abs(dct_coefficients) + epsilon).astype(np.uint8)
        mask = cv2.resize(mask, (img.shape[1], img.shape[0]))


        if self.mode == 0:
            return cv2.bitwise_and(img, img, mask=mask)
        else:
            dct_coefficients = np.asarray(dct_coefficients)
            image = cv2.merge((dct_coefficients, dct_coefficients, dct_coefficients))
            return image
In [77]:
import albumentations as A
def get_transforms(* , data) : 
    size = CFG.size

    if data == 'train':
        return Compose([
        ImageCompression(quality_lower=40, quality_upper=100, p=0.1),
        HorizontalFlip(),
        GaussNoise(p=0.3),
        ISONoise(p=0.3),
        MultiplicativeNoise(p=0.3),
        OneOf([
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
            CustomRandomCrop(size=size)
        ], p=1),
        Resize(height=size, width=size),
        PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT , value=0 , p=1),
        OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.5),
        OneOf([CoarseDropout()], p=0.05),
        ToGray(p=0.1),
        ToSepia(p=0.05),
        RandomShadow(p=0.05),
        RandomGamma(p=0.1),
        ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
        FFT(mode=0, p=0.05),
        DCT(mode=1, p=0.5) ,
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ])
    elif data == 'valid':
        return Compose([
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
            Resize(CFG.size, CFG.size),
#             PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT , value=0 ),
            
            
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ])
In [78]:
train_dataset = TrainDataset(train , transform= get_transforms(data = "train"))
In [79]:
fig, axes = plt.subplots(2, 4, figsize=(10, 7))

for i in range(2):
    for j in range(4):
        index = i * 3 + j
        if index < len(train_dataset):
            image, label = train_dataset[index]
            axes[i, j].imshow(image.permute(1,2,0))
            if label.numpy() == 1:
                axes[i, j].set_title("Fake", color="r")
            else:
                axes[i, j].set_title("Real", color="g")
            axes[i, j].axis('off')

plt.tight_layout()
plt.show()
In [80]:
folds = train.copy()
Fold = KFold(n_splits = CFG.n_fold  , shuffle = True , random_state = CFG.seed)
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds[CFG.target_col])):
    folds.loc[val_index, 'fold'] = int(n)
folds['fold'] = folds['fold'].astype(int)
In [88]:
model = Two_Stream_Net()
model(train_dataset[0][0].unsqueeze(1).permute(1,0,2,3))
Using dropout 0.5
Using dropout 0.5
Out[88]:
tensor([[-1.3592,  0.0468]], grad_fn=<AddmmBackward0>)
In [83]:
import wandb

try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("wandb_api")
    wandb.login(key=api_key)
    anonymous = None
except:
    anonymous = "must"
    print('To use your W&B account,\nGo to Add-ons -> Secrets and provide your W&B access token. Use the Label name as WANDB. \nGet your W&B access token from here: https://wandb.ai/authorize')
wandb: W&B API key is configured. Use `wandb login --relogin` to force relogin
wandb: WARNING If you're specifying your api key in code, ensure this code is not shared publicly.
wandb: WARNING Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
In [84]:
run = wandb.init(entity = 'lassouedaymenla',
                 project = 'tutorial',
                 save_code = True,
                 name = "FaceForgery"
)
wandb: Currently logged in as: lassouedaymenla. Use `wandb login --relogin` to force relogin
wandb version 0.17.5 is available! To upgrade, please run: $ pip install wandb --upgrade
Tracking run with wandb version 0.17.4
Run data is saved locally in /kaggle/working/wandb/run-20240722_172435-gi80kvwo
Syncing run FaceForgery to Weights & Biases (docs)
View project at https://wandb.ai/lassouedaymenla/tutorial
View run at https://wandb.ai/lassouedaymenla/tutorial/runs/gi80kvwo
In [85]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))


def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    global_step = 0
    for step, (images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        y_preds = model(images)
        
        labels = labels.cuda()
        y_preds = y_preds.cuda()
# debug
#         print(torch.nn.functional.softmax(y_preds, dim=1))
#         print(labels)
        loss = criterion(y_preds, labels)
        # record loss
        losses.update(loss.item(), batch_size)
        
#         # Logging to wandb
#         wandb.log({"Training Loss": loss.item(), "Epoch": epoch, "Step": global_step})
        
        
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        if CFG.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Grad: {grad_norm:.4f}  '
                  #'LR: {lr:.6f}  '
                  .format(
                   epoch+1, step, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses,
                   remain=timeSince(start, float(step+1)/len(train_loader)),
                   grad_norm=grad_norm,
                   #lr=scheduler.get_lr()[0],
                   ))
            
#     # Log epoch summary to wandb
#     wandb.log({"Epoch Training Loss": losses.avg, "Epoch": epoch})

        wandb.log({
    "Train Loss": losses.val,
    "Step": step,
    "Gradient Norm": grad_norm,
    "Learning Rate": optimizer.param_groups[0]['lr']  # Add this line to log the learning rate
})
    return losses.avg


def valid_fn(valid_loader, model, criterion, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to evaluation mode
    model.eval()
    preds = []
    start = end = time.time()
    for step, (images, labels) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        # compute loss
        with torch.no_grad():
            y_preds = model(images)
            
        labels = labels.cuda()
        y_preds = y_preds.cuda()
        
        loss = criterion(y_preds, labels)
        losses.update(loss.item(), batch_size)
        
        y_preds = torch.nn.functional.softmax(y_preds, dim=1)
        # record accuracy
        y_preds = y_preds.to('cpu').numpy()
       
       

        
        preds.append(y_preds)
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            print('EVAL: [{0}/{1}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  .format(
                   step, len(valid_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses,
                   remain=timeSince(start, float(step+1)/len(valid_loader)),
                   ))
        wandb.log({
            "Val Loss ": losses.val,
            "Val Step": step , 
            
        })           
    predictions = np.concatenate(preds)
    return losses.avg, predictions



def inference(model, states, test_loader, device):
    model.to(device)
    tk0 = tqdm(enumerate(test_loader), total=len(test_loader))
    probs = []
    for i, (images) in tk0:
        images = images.to(device)
        avg_preds = []
        for state in states:
            model.load_state_dict(state['model'])
            model.eval()
            with torch.no_grad():
#                 print(images.shape)
                y_preds = model(images)
            avg_preds.append(y_preds.to('cpu').numpy())
        avg_preds = np.mean(avg_preds, axis=0)
        probs.append(avg_preds)
    probs = np.concatenate(probs)
    return probs
In [89]:
def train_loop(folds, fold):

    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)

    train_dataset = TrainDataset(train_folds, 
                                 transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds, 
                                 transform=get_transforms(data='valid'))

    train_loader = DataLoader(train_dataset, 
                              batch_size=CFG.batch_size, 
                              shuffle=False, 
                              num_workers=CFG.num_workers, sampler=BalanceClassSampler(labels=train_dataset.get_labels(), mode="upsampling") , 
                              pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset, 
                              batch_size=CFG.batch_size, 
                              shuffle=False, 
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    
    # ====================================================
    # scheduler 
    # ====================================================
    def get_scheduler(optimizer):
        if CFG.scheduler=='ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
        elif CFG.scheduler=='CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
        elif CFG.scheduler=='CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
        return scheduler

    # ====================================================
    # model & optimizer
    # ====================================================
    model = Two_Stream_Net()
    model.to(device)

    optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, amsgrad=False)
    scheduler = get_scheduler(optimizer)

    # ====================================================
    # apex
    # ====================================================
    if CFG.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

    # ====================================================
    # loop
    # ====================================================
    criterion = nn.CrossEntropyLoss().cuda()


    best_score = 50000
    best_loss = np.inf
    
    for epoch in range(CFG.epochs):
        
        start_time = time.time()
        
        # train
        avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        avg_val_loss, preds = valid_fn(valid_loader, model, criterion, device)
        valid_labels = valid_folds[CFG.target_col].values
        
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, CosineAnnealingLR):
            scheduler.step()
        elif isinstance(scheduler, CosineAnnealingWarmRestarts):
            scheduler.step()

        # scoring
        score = get_score(valid_labels, preds)
        print(score)
        preds= torch.nn.functional.softmax(torch.from_numpy(preds), dim=1).numpy()[:,1]
        score2 = roc_auc_score(valid_labels, preds)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s') #.info makes the msg shows in red cadre
        LOGGER.info(f'Epoch {epoch+1} - LogLoss: {score} - AUC: {score2}')

        if score < best_score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict(), 
                        'preds': preds},
                        OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.pth')
    
    check_point = torch.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.pth')
    #valid_folds[[str(c) for c in range(5)]] = check_point['preds']
    #valid_folds['preds'] = check_point['preds'].argmax(1)

    return 
In [90]:
def main():

    """
    Prepare: 1.train  2.test  3.submission  4.folds
    """

    def get_result(result_df):
        preds = result_df['preds'].values
        labels = result_df[CFG.target_col].values
        score = get_score(labels, preds)
        LOGGER.info(f'Score: {score:<.5f}')
    
    if CFG.train:
        # train 
        oof_df = pd.DataFrame()
        for fold in range(CFG.n_fold):
            if fold in CFG.trn_fold:
                train_loop(folds, fold)
                #oof_df = pd.concat([oof_df, _oof_df])
                #LOGGER.info(f"========== fold: {fold} result ==========")
                #get_result(_oof_df)
        # CV result
        LOGGER.info(f"========== CV ==========")
        #get_result(oof_df)
        # save result
        #oof_df.to_csv(OUTPUT_DIR+'oof_df.csv', index=False)
    
    if CFG.inference:
        # inference
        model = CustomResNext(CFG.model_name, pretrained=False)
        states = [torch.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.pth') for fold in CFG.trn_fold]
        test_dataset = TestDataset(test, transform=get_transforms(data='valid'))
        test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False, 
                                 num_workers=CFG.num_workers, pin_memory=True)
        predictions = inference(model, states, test_loader, device)
        # submission
        print(predictions)
        test['label'] = torch.nn.functional.softmax(torch.from_numpy(predictions), dim=1).numpy()[:,1]
        print(test['label'])
        test[['img_name', 'label']].to_csv(OUTPUT_DIR+'submission.csv', index=False)
In [ ]:
if __name__ == '__main__':                                                                                              
    main()
========== fold: 0 training ==========
Using dropout 0.5
Using dropout 0.5
Epoch: [1][0/40998] Data 1.775 (1.775) Elapsed 0m 4s (remain 3105m 11s) Loss: 0.8732(0.8732) Grad: 13.5833  
Epoch: [1][20/40998] Data 0.552 (0.585) Elapsed 0m 21s (remain 684m 25s) Loss: 0.7016(0.8085) Grad: 8.5844  
Epoch: [1][40/40998] Data 0.552 (0.569) Elapsed 0m 38s (remain 634m 19s) Loss: 0.8539(0.7425) Grad: 10.9534  
Epoch: [1][60/40998] Data 0.553 (0.564) Elapsed 0m 55s (remain 616m 52s) Loss: 0.6197(0.7149) Grad: 7.0083  
Epoch: [1][80/40998] Data 0.553 (0.561) Elapsed 1m 12s (remain 607m 52s) Loss: 0.5332(0.6929) Grad: 7.3943  
Epoch: [1][100/40998] Data 0.553 (0.559) Elapsed 1m 29s (remain 602m 20s) Loss: 0.5192(0.6752) Grad: 6.0233  
Epoch: [1][120/40998] Data 0.552 (0.558) Elapsed 1m 46s (remain 598m 32s) Loss: 0.6919(0.6608) Grad: 11.2926  
Epoch: [1][140/40998] Data 0.553 (0.558) Elapsed 2m 3s (remain 595m 44s) Loss: 0.5706(0.6463) Grad: 6.2257  
Epoch: [1][160/40998] Data 0.553 (0.557) Elapsed 2m 20s (remain 593m 34s) Loss: 0.7004(0.6316) Grad: 7.9941  
Epoch: [1][180/40998] Data 0.552 (0.556) Elapsed 2m 37s (remain 591m 49s) Loss: 0.5212(0.6236) Grad: 4.7840  
Epoch: [1][200/40998] Data 0.553 (0.556) Elapsed 2m 54s (remain 590m 21s) Loss: 0.6489(0.6154) Grad: 6.1323  
Epoch: [1][220/40998] Data 0.552 (0.556) Elapsed 3m 11s (remain 589m 6s) Loss: 0.5790(0.6054) Grad: 7.4238  
Epoch: [1][240/40998] Data 0.551 (0.556) Elapsed 3m 28s (remain 588m 0s) Loss: 0.6314(0.5980) Grad: 7.0559  
Epoch: [1][260/40998] Data 0.553 (0.555) Elapsed 3m 45s (remain 587m 2s) Loss: 0.4474(0.5893) Grad: 4.5827  
Epoch: [1][280/40998] Data 0.552 (0.555) Elapsed 4m 2s (remain 586m 10s) Loss: 0.2868(0.5788) Grad: 3.7827  
Epoch: [1][300/40998] Data 0.553 (0.555) Elapsed 4m 19s (remain 585m 22s) Loss: 0.3590(0.5748) Grad: 3.8241  
Epoch: [1][320/40998] Data 0.553 (0.555) Elapsed 4m 36s (remain 584m 39s) Loss: 0.4217(0.5700) Grad: 4.3578  
Epoch: [1][340/40998] Data 0.553 (0.555) Elapsed 4m 53s (remain 583m 58s) Loss: 0.5672(0.5651) Grad: 5.0714  
Epoch: [1][360/40998] Data 0.553 (0.555) Elapsed 5m 10s (remain 583m 20s) Loss: 0.5896(0.5608) Grad: 6.0533  
Epoch: [1][380/40998] Data 0.552 (0.554) Elapsed 5m 27s (remain 582m 44s) Loss: 0.6467(0.5573) Grad: 6.7685  
Epoch: [1][400/40998] Data 0.553 (0.554) Elapsed 5m 45s (remain 582m 10s) Loss: 0.2545(0.5520) Grad: 4.3219  
Epoch: [1][420/40998] Data 0.553 (0.554) Elapsed 6m 2s (remain 581m 38s) Loss: 0.4342(0.5507) Grad: 3.2260  
Epoch: [1][440/40998] Data 0.552 (0.554) Elapsed 6m 19s (remain 581m 8s) Loss: 0.5479(0.5483) Grad: 5.4349  
Epoch: [1][460/40998] Data 0.553 (0.554) Elapsed 6m 36s (remain 580m 37s) Loss: 0.4718(0.5442) Grad: 4.3204  
Epoch: [1][480/40998] Data 0.553 (0.554) Elapsed 6m 53s (remain 580m 9s) Loss: 0.4894(0.5410) Grad: 4.7152  
Epoch: [1][500/40998] Data 0.552 (0.554) Elapsed 7m 10s (remain 579m 41s) Loss: 0.4900(0.5373) Grad: 4.7308  
Epoch: [1][520/40998] Data 0.552 (0.554) Elapsed 7m 27s (remain 579m 15s) Loss: 0.2977(0.5369) Grad: 2.6538  
Epoch: [1][540/40998] Data 0.553 (0.554) Elapsed 7m 44s (remain 578m 48s) Loss: 0.3365(0.5333) Grad: 3.6212  
Epoch: [1][560/40998] Data 0.553 (0.554) Elapsed 8m 1s (remain 578m 23s) Loss: 0.7219(0.5302) Grad: 8.6209  
Epoch: [1][580/40998] Data 0.553 (0.554) Elapsed 8m 18s (remain 577m 58s) Loss: 0.4107(0.5271) Grad: 3.9289  
Epoch: [1][600/40998] Data 0.553 (0.554) Elapsed 8m 35s (remain 577m 33s) Loss: 0.4708(0.5250) Grad: 4.1229  
Epoch: [1][620/40998] Data 0.553 (0.554) Elapsed 8m 52s (remain 577m 9s) Loss: 0.4260(0.5233) Grad: 4.5030  
Epoch: [1][640/40998] Data 0.552 (0.554) Elapsed 9m 9s (remain 576m 45s) Loss: 0.7291(0.5207) Grad: 6.3236  
Epoch: [1][660/40998] Data 0.553 (0.554) Elapsed 9m 26s (remain 576m 22s) Loss: 0.4016(0.5187) Grad: 3.1535  
Epoch: [1][680/40998] Data 0.553 (0.554) Elapsed 9m 43s (remain 575m 59s) Loss: 0.3967(0.5166) Grad: 3.5701  
Epoch: [1][700/40998] Data 0.553 (0.554) Elapsed 10m 0s (remain 575m 37s) Loss: 0.4938(0.5147) Grad: 5.0574  
Epoch: [1][720/40998] Data 0.552 (0.554) Elapsed 10m 17s (remain 575m 15s) Loss: 0.4573(0.5133) Grad: 4.9107  
Epoch: [1][740/40998] Data 0.553 (0.554) Elapsed 10m 34s (remain 574m 53s) Loss: 0.3223(0.5107) Grad: 3.1092  
Epoch: [1][760/40998] Data 0.553 (0.554) Elapsed 10m 51s (remain 574m 32s) Loss: 0.2485(0.5070) Grad: 3.5340  
Epoch: [1][780/40998] Data 0.553 (0.554) Elapsed 11m 9s (remain 574m 10s) Loss: 0.4243(0.5047) Grad: 5.0002  
Epoch: [1][800/40998] Data 0.553 (0.554) Elapsed 11m 26s (remain 573m 49s) Loss: 0.5623(0.5030) Grad: 5.6730  
Epoch: [1][820/40998] Data 0.553 (0.554) Elapsed 11m 43s (remain 573m 28s) Loss: 0.3768(0.5005) Grad: 3.8652  
Epoch: [1][840/40998] Data 0.553 (0.554) Elapsed 12m 0s (remain 573m 8s) Loss: 0.3171(0.4984) Grad: 4.0671  
Epoch: [1][860/40998] Data 0.553 (0.554) Elapsed 12m 17s (remain 572m 47s) Loss: 0.5697(0.4968) Grad: 5.0413  
Epoch: [1][880/40998] Data 0.553 (0.554) Elapsed 12m 34s (remain 572m 27s) Loss: 0.2385(0.4948) Grad: 3.0250  
Epoch: [1][900/40998] Data 0.553 (0.554) Elapsed 12m 51s (remain 572m 6s) Loss: 0.3619(0.4934) Grad: 4.8707  
Epoch: [1][920/40998] Data 0.553 (0.554) Elapsed 13m 8s (remain 571m 46s) Loss: 0.4019(0.4920) Grad: 2.8238  
Epoch: [1][940/40998] Data 0.552 (0.553) Elapsed 13m 25s (remain 571m 29s) Loss: 0.6209(0.4903) Grad: 5.1697  
Epoch: [1][960/40998] Data 0.552 (0.553) Elapsed 13m 42s (remain 571m 9s) Loss: 0.4208(0.4884) Grad: 3.7733  
Epoch: [1][980/40998] Data 0.553 (0.553) Elapsed 13m 59s (remain 570m 49s) Loss: 0.4393(0.4873) Grad: 3.6124  
Epoch: [1][1000/40998] Data 0.552 (0.553) Elapsed 14m 16s (remain 570m 30s) Loss: 0.4038(0.4851) Grad: 4.7643  
Epoch: [1][1020/40998] Data 0.553 (0.553) Elapsed 14m 33s (remain 570m 11s) Loss: 0.3471(0.4838) Grad: 3.6725  
Epoch: [1][1040/40998] Data 0.552 (0.553) Elapsed 14m 50s (remain 569m 52s) Loss: 0.4080(0.4822) Grad: 4.1811  
Epoch: [1][1060/40998] Data 0.552 (0.553) Elapsed 15m 7s (remain 569m 32s) Loss: 0.4326(0.4809) Grad: 4.6237  
Epoch: [1][1080/40998] Data 0.553 (0.553) Elapsed 15m 24s (remain 569m 13s) Loss: 0.3635(0.4792) Grad: 3.5998  
Epoch: [1][1100/40998] Data 0.553 (0.553) Elapsed 15m 41s (remain 568m 54s) Loss: 0.2735(0.4782) Grad: 2.8812  
Epoch: [1][1120/40998] Data 0.553 (0.553) Elapsed 15m 59s (remain 568m 36s) Loss: 0.2919(0.4769) Grad: 4.5212  
Epoch: [1][1140/40998] Data 0.552 (0.553) Elapsed 16m 16s (remain 568m 17s) Loss: 0.4960(0.4759) Grad: 4.6030  
Epoch: [1][1160/40998] Data 0.553 (0.553) Elapsed 16m 33s (remain 567m 58s) Loss: 0.5470(0.4740) Grad: 7.2479  
Epoch: [1][1180/40998] Data 0.553 (0.553) Elapsed 16m 50s (remain 567m 39s) Loss: 0.5810(0.4730) Grad: 5.1567  
Epoch: [1][1200/40998] Data 0.552 (0.553) Elapsed 17m 7s (remain 567m 21s) Loss: 0.3351(0.4714) Grad: 4.3528  
Epoch: [1][1220/40998] Data 0.552 (0.553) Elapsed 17m 24s (remain 567m 2s) Loss: 0.3622(0.4700) Grad: 3.8108  
Epoch: [1][1240/40998] Data 0.552 (0.553) Elapsed 17m 41s (remain 566m 44s) Loss: 0.4086(0.4690) Grad: 4.0388  
Epoch: [1][1260/40998] Data 0.553 (0.553) Elapsed 17m 58s (remain 566m 25s) Loss: 0.2695(0.4673) Grad: 2.8184  
Epoch: [1][1280/40998] Data 0.553 (0.553) Elapsed 18m 15s (remain 566m 7s) Loss: 0.4850(0.4659) Grad: 4.4943  
Epoch: [1][1300/40998] Data 0.553 (0.553) Elapsed 18m 32s (remain 565m 48s) Loss: 0.5135(0.4646) Grad: 6.2567  
Epoch: [1][1320/40998] Data 0.553 (0.553) Elapsed 18m 49s (remain 565m 30s) Loss: 0.2715(0.4629) Grad: 2.6880  
Epoch: [1][1340/40998] Data 0.553 (0.553) Elapsed 19m 6s (remain 565m 12s) Loss: 0.3877(0.4619) Grad: 3.2289  
Epoch: [1][1360/40998] Data 0.553 (0.553) Elapsed 19m 23s (remain 564m 53s) Loss: 0.3231(0.4608) Grad: 3.7210  
Epoch: [1][1380/40998] Data 0.553 (0.553) Elapsed 19m 40s (remain 564m 35s) Loss: 0.4805(0.4588) Grad: 5.1187  
Epoch: [1][1400/40998] Data 0.553 (0.553) Elapsed 19m 57s (remain 564m 17s) Loss: 0.1662(0.4573) Grad: 3.3010  
Epoch: [1][1420/40998] Data 0.553 (0.553) Elapsed 20m 14s (remain 563m 59s) Loss: 0.3386(0.4555) Grad: 4.0177  
Epoch: [1][1440/40998] Data 0.553 (0.553) Elapsed 20m 32s (remain 563m 41s) Loss: 0.6130(0.4542) Grad: 6.4697  
Epoch: [1][1460/40998] Data 0.553 (0.553) Elapsed 20m 49s (remain 563m 23s) Loss: 0.2941(0.4521) Grad: 3.5293  
Epoch: [1][1480/40998] Data 0.553 (0.553) Elapsed 21m 6s (remain 563m 4s) Loss: 0.4725(0.4512) Grad: 6.7537  
Epoch: [1][1500/40998] Data 0.552 (0.553) Elapsed 21m 23s (remain 562m 46s) Loss: 0.5451(0.4503) Grad: 5.2309  
Epoch: [1][1520/40998] Data 0.553 (0.553) Elapsed 21m 40s (remain 562m 28s) Loss: 0.2884(0.4493) Grad: 3.9010  
Epoch: [1][1540/40998] Data 0.552 (0.553) Elapsed 21m 57s (remain 562m 10s) Loss: 0.3341(0.4481) Grad: 2.8484  
Epoch: [1][1560/40998] Data 0.552 (0.553) Elapsed 22m 14s (remain 561m 52s) Loss: 0.2223(0.4467) Grad: 2.4516  
Epoch: [1][1580/40998] Data 0.553 (0.553) Elapsed 22m 31s (remain 561m 34s) Loss: 0.2608(0.4451) Grad: 4.6336  
Epoch: [1][1600/40998] Data 0.552 (0.553) Elapsed 22m 48s (remain 561m 17s) Loss: 0.3960(0.4441) Grad: 3.7286  
Epoch: [1][1620/40998] Data 0.552 (0.553) Elapsed 23m 5s (remain 560m 59s) Loss: 0.5299(0.4433) Grad: 6.1462  
Epoch: [1][1640/40998] Data 0.553 (0.553) Elapsed 23m 22s (remain 560m 41s) Loss: 0.1687(0.4425) Grad: 1.7323  
Epoch: [1][1660/40998] Data 0.553 (0.553) Elapsed 23m 39s (remain 560m 23s) Loss: 0.2183(0.4418) Grad: 2.6848  
Epoch: [1][1680/40998] Data 0.553 (0.553) Elapsed 23m 56s (remain 560m 5s) Loss: 0.4905(0.4404) Grad: 4.7968  
Epoch: [1][1700/40998] Data 0.553 (0.553) Elapsed 24m 13s (remain 559m 47s) Loss: 0.4574(0.4398) Grad: 4.9223  
Epoch: [1][1720/40998] Data 0.553 (0.553) Elapsed 24m 30s (remain 559m 29s) Loss: 0.2995(0.4394) Grad: 3.5081  
Epoch: [1][1740/40998] Data 0.553 (0.553) Elapsed 24m 47s (remain 559m 11s) Loss: 0.4469(0.4380) Grad: 3.8852  
Epoch: [1][1760/40998] Data 0.553 (0.553) Elapsed 25m 5s (remain 558m 54s) Loss: 0.4500(0.4374) Grad: 3.9600  
Epoch: [1][1780/40998] Data 0.552 (0.553) Elapsed 25m 22s (remain 558m 36s) Loss: 0.8086(0.4368) Grad: 5.9817  
Epoch: [1][1800/40998] Data 0.552 (0.553) Elapsed 25m 39s (remain 558m 18s) Loss: 0.2335(0.4357) Grad: 2.5538  
Epoch: [1][1820/40998] Data 0.553 (0.553) Elapsed 25m 56s (remain 558m 0s) Loss: 0.2386(0.4345) Grad: 3.7621  
Epoch: [1][1840/40998] Data 0.553 (0.553) Elapsed 26m 13s (remain 557m 43s) Loss: 0.1864(0.4334) Grad: 2.3264  
Epoch: [1][1860/40998] Data 0.551 (0.553) Elapsed 26m 30s (remain 557m 25s) Loss: 0.1028(0.4325) Grad: 1.5620  
Epoch: [1][1880/40998] Data 0.553 (0.553) Elapsed 26m 47s (remain 557m 7s) Loss: 0.1165(0.4315) Grad: 1.5070  
Epoch: [1][1900/40998] Data 0.553 (0.553) Elapsed 27m 4s (remain 556m 50s) Loss: 0.6323(0.4310) Grad: 4.1203  
Epoch: [1][1920/40998] Data 0.553 (0.553) Elapsed 27m 21s (remain 556m 32s) Loss: 0.4528(0.4302) Grad: 4.6585  
Epoch: [1][1940/40998] Data 0.552 (0.553) Elapsed 27m 38s (remain 556m 14s) Loss: 0.2887(0.4288) Grad: 3.0704  
Epoch: [1][1960/40998] Data 0.552 (0.553) Elapsed 27m 55s (remain 555m 56s) Loss: 0.2247(0.4284) Grad: 2.3573  
Epoch: [1][1980/40998] Data 0.553 (0.553) Elapsed 28m 12s (remain 555m 39s) Loss: 0.4611(0.4277) Grad: 3.7604  
Epoch: [1][2000/40998] Data 0.553 (0.553) Elapsed 28m 29s (remain 555m 21s) Loss: 0.3091(0.4266) Grad: 3.2972  
Epoch: [1][2020/40998] Data 0.552 (0.553) Elapsed 28m 46s (remain 555m 3s) Loss: 0.4423(0.4258) Grad: 3.3024  
Epoch: [1][2040/40998] Data 0.553 (0.553) Elapsed 29m 3s (remain 554m 46s) Loss: 0.3166(0.4249) Grad: 4.1661  
Epoch: [1][2060/40998] Data 0.553 (0.553) Elapsed 29m 20s (remain 554m 28s) Loss: 0.2119(0.4242) Grad: 3.0091  
Epoch: [1][2080/40998] Data 0.553 (0.553) Elapsed 29m 38s (remain 554m 11s) Loss: 0.1845(0.4236) Grad: 2.2240  
Epoch: [1][2100/40998] Data 0.553 (0.553) Elapsed 29m 55s (remain 553m 53s) Loss: 0.1722(0.4229) Grad: 1.4631  
Epoch: [1][2120/40998] Data 0.553 (0.553) Elapsed 30m 12s (remain 553m 35s) Loss: 0.4426(0.4222) Grad: 3.5704  
Epoch: [1][2140/40998] Data 0.552 (0.553) Elapsed 30m 29s (remain 553m 18s) Loss: 0.4389(0.4214) Grad: 4.6936  
Epoch: [1][2160/40998] Data 0.553 (0.553) Elapsed 30m 46s (remain 553m 0s) Loss: 0.3045(0.4207) Grad: 3.4256  
Epoch: [1][2180/40998] Data 0.553 (0.553) Elapsed 31m 3s (remain 552m 43s) Loss: 0.2534(0.4202) Grad: 3.2902  
Epoch: [1][2200/40998] Data 0.553 (0.553) Elapsed 31m 20s (remain 552m 25s) Loss: 0.1792(0.4189) Grad: 2.0067  
Epoch: [1][2220/40998] Data 0.555 (0.553) Elapsed 31m 37s (remain 552m 8s) Loss: 0.2259(0.4180) Grad: 3.2194  
Epoch: [1][2240/40998] Data 0.553 (0.553) Elapsed 31m 54s (remain 551m 50s) Loss: 0.2489(0.4174) Grad: 3.8718  
Epoch: [1][2260/40998] Data 0.553 (0.553) Elapsed 32m 11s (remain 551m 33s) Loss: 0.2170(0.4167) Grad: 2.2986  
Epoch: [1][2280/40998] Data 0.553 (0.553) Elapsed 32m 28s (remain 551m 15s) Loss: 0.4465(0.4161) Grad: 3.4946  
Epoch: [1][2300/40998] Data 0.553 (0.553) Elapsed 32m 45s (remain 550m 58s) Loss: 0.3473(0.4154) Grad: 3.8170  
Epoch: [1][2320/40998] Data 0.553 (0.553) Elapsed 33m 2s (remain 550m 40s) Loss: 0.4576(0.4147) Grad: 3.2494  
Epoch: [1][2340/40998] Data 0.551 (0.553) Elapsed 33m 19s (remain 550m 23s) Loss: 0.3566(0.4137) Grad: 2.9380  
Epoch: [1][2360/40998] Data 0.552 (0.553) Elapsed 33m 36s (remain 550m 5s) Loss: 0.1940(0.4128) Grad: 2.9704  
Epoch: [1][2380/40998] Data 0.553 (0.553) Elapsed 33m 53s (remain 549m 48s) Loss: 0.1885(0.4123) Grad: 2.3402  
Epoch: [1][2400/40998] Data 0.553 (0.553) Elapsed 34m 11s (remain 549m 30s) Loss: 0.3742(0.4118) Grad: 3.9260  
Epoch: [1][2420/40998] Data 0.553 (0.553) Elapsed 34m 28s (remain 549m 13s) Loss: 0.4667(0.4115) Grad: 3.1082  
Epoch: [1][2440/40998] Data 0.552 (0.553) Elapsed 34m 45s (remain 548m 55s) Loss: 0.3068(0.4108) Grad: 3.4124  
Epoch: [1][2460/40998] Data 0.553 (0.553) Elapsed 35m 2s (remain 548m 38s) Loss: 0.1703(0.4099) Grad: 2.3023  
Epoch: [1][2480/40998] Data 0.553 (0.553) Elapsed 35m 19s (remain 548m 20s) Loss: 0.3062(0.4092) Grad: 3.6527  
Epoch: [1][2500/40998] Data 0.553 (0.553) Elapsed 35m 36s (remain 548m 3s) Loss: 0.5162(0.4085) Grad: 5.0258  
Epoch: [1][2520/40998] Data 0.553 (0.553) Elapsed 35m 53s (remain 547m 46s) Loss: 0.3871(0.4076) Grad: 3.3546  
Epoch: [1][2540/40998] Data 0.552 (0.553) Elapsed 36m 10s (remain 547m 28s) Loss: 0.1952(0.4071) Grad: 2.2715  
Epoch: [1][2560/40998] Data 0.552 (0.553) Elapsed 36m 27s (remain 547m 11s) Loss: 0.2784(0.4067) Grad: 1.9070  
Epoch: [1][2580/40998] Data 0.553 (0.553) Elapsed 36m 44s (remain 546m 53s) Loss: 0.2288(0.4060) Grad: 2.5700  
Epoch: [1][2600/40998] Data 0.553 (0.553) Elapsed 37m 1s (remain 546m 36s) Loss: 0.8520(0.4052) Grad: 8.7013  
Epoch: [1][2620/40998] Data 0.552 (0.553) Elapsed 37m 18s (remain 546m 18s) Loss: 0.2892(0.4047) Grad: 2.6977  
Epoch: [1][2640/40998] Data 0.552 (0.553) Elapsed 37m 35s (remain 546m 1s) Loss: 0.2398(0.4041) Grad: 1.7540  
Epoch: [1][2660/40998] Data 0.553 (0.553) Elapsed 37m 52s (remain 545m 44s) Loss: 0.2134(0.4033) Grad: 1.8970  
Epoch: [1][2680/40998] Data 0.553 (0.553) Elapsed 38m 9s (remain 545m 26s) Loss: 0.1671(0.4026) Grad: 2.1689  
Epoch: [1][2700/40998] Data 0.553 (0.553) Elapsed 38m 26s (remain 545m 9s) Loss: 0.1771(0.4021) Grad: 1.6560  
Epoch: [1][2720/40998] Data 0.553 (0.553) Elapsed 38m 43s (remain 544m 51s) Loss: 0.2230(0.4016) Grad: 2.3146  
Epoch: [1][2740/40998] Data 0.553 (0.553) Elapsed 39m 1s (remain 544m 34s) Loss: 0.2867(0.4009) Grad: 2.9328  
Epoch: [1][2760/40998] Data 0.553 (0.553) Elapsed 39m 18s (remain 544m 17s) Loss: 0.2763(0.4005) Grad: 3.3305  
Epoch: [1][2780/40998] Data 0.553 (0.553) Elapsed 39m 35s (remain 543m 59s) Loss: 0.2568(0.4001) Grad: 2.3604  
Epoch: [1][2800/40998] Data 0.553 (0.553) Elapsed 39m 52s (remain 543m 42s) Loss: 0.3026(0.3994) Grad: 4.7383  
Epoch: [1][2820/40998] Data 0.552 (0.553) Elapsed 40m 9s (remain 543m 25s) Loss: 0.2071(0.3989) Grad: 2.3130  
Epoch: [1][2840/40998] Data 0.553 (0.553) Elapsed 40m 26s (remain 543m 7s) Loss: 0.3456(0.3981) Grad: 3.2836  
Epoch: [1][2860/40998] Data 0.553 (0.553) Elapsed 40m 43s (remain 542m 50s) Loss: 0.2130(0.3973) Grad: 2.6025  
Epoch: [1][2880/40998] Data 0.552 (0.553) Elapsed 41m 0s (remain 542m 33s) Loss: 0.4571(0.3966) Grad: 5.9418  
Epoch: [1][2900/40998] Data 0.553 (0.553) Elapsed 41m 17s (remain 542m 15s) Loss: 0.2048(0.3962) Grad: 3.8048  
Epoch: [1][2920/40998] Data 0.553 (0.553) Elapsed 41m 34s (remain 541m 58s) Loss: 0.2103(0.3954) Grad: 3.1009  
Epoch: [1][2940/40998] Data 0.553 (0.553) Elapsed 41m 51s (remain 541m 41s) Loss: 0.4970(0.3944) Grad: 4.4819  
Epoch: [1][2960/40998] Data 0.553 (0.553) Elapsed 42m 8s (remain 541m 23s) Loss: 0.4697(0.3940) Grad: 3.6993  
Epoch: [1][2980/40998] Data 0.552 (0.553) Elapsed 42m 25s (remain 541m 6s) Loss: 0.4027(0.3937) Grad: 2.8779  
Epoch: [1][3000/40998] Data 0.552 (0.553) Elapsed 42m 42s (remain 540m 49s) Loss: 0.3300(0.3933) Grad: 3.1115  
Epoch: [1][3020/40998] Data 0.553 (0.553) Elapsed 42m 59s (remain 540m 31s) Loss: 0.4063(0.3928) Grad: 5.5049  
Epoch: [1][3040/40998] Data 0.553 (0.553) Elapsed 43m 16s (remain 540m 14s) Loss: 0.2783(0.3925) Grad: 3.4340  
Epoch: [1][3060/40998] Data 0.552 (0.553) Elapsed 43m 34s (remain 539m 57s) Loss: 0.1879(0.3920) Grad: 2.0003  
Epoch: [1][3080/40998] Data 0.553 (0.553) Elapsed 43m 51s (remain 539m 39s) Loss: 0.2985(0.3916) Grad: 3.6128  
Epoch: [1][3100/40998] Data 0.552 (0.553) Elapsed 44m 8s (remain 539m 22s) Loss: 0.2297(0.3912) Grad: 3.2206  
Epoch: [1][3120/40998] Data 0.553 (0.553) Elapsed 44m 25s (remain 539m 5s) Loss: 0.6284(0.3911) Grad: 5.7342  
Epoch: [1][3140/40998] Data 0.553 (0.553) Elapsed 44m 42s (remain 538m 47s) Loss: 0.4407(0.3906) Grad: 4.5784  
Epoch: [1][3160/40998] Data 0.552 (0.553) Elapsed 44m 59s (remain 538m 30s) Loss: 0.1959(0.3900) Grad: 1.8897  
Epoch: [1][3180/40998] Data 0.552 (0.553) Elapsed 45m 16s (remain 538m 13s) Loss: 0.3527(0.3895) Grad: 4.0058  
Epoch: [1][3200/40998] Data 0.553 (0.553) Elapsed 45m 33s (remain 537m 55s) Loss: 0.4527(0.3890) Grad: 4.5512  
Epoch: [1][3220/40998] Data 0.552 (0.553) Elapsed 45m 50s (remain 537m 38s) Loss: 0.4978(0.3886) Grad: 3.6290  
Epoch: [1][3240/40998] Data 0.553 (0.553) Elapsed 46m 7s (remain 537m 21s) Loss: 0.2094(0.3880) Grad: 2.2683  
Epoch: [1][3260/40998] Data 0.553 (0.553) Elapsed 46m 24s (remain 537m 4s) Loss: 0.3202(0.3876) Grad: 3.4282  
Epoch: [1][3280/40998] Data 0.553 (0.553) Elapsed 46m 41s (remain 536m 46s) Loss: 0.2276(0.3870) Grad: 2.5730  
Epoch: [1][3300/40998] Data 0.553 (0.553) Elapsed 46m 58s (remain 536m 29s) Loss: 0.2211(0.3863) Grad: 2.2181  
Epoch: [1][3320/40998] Data 0.552 (0.553) Elapsed 47m 15s (remain 536m 12s) Loss: 0.3167(0.3857) Grad: 3.7030  
Epoch: [1][3340/40998] Data 0.553 (0.553) Elapsed 47m 32s (remain 535m 54s) Loss: 0.2307(0.3852) Grad: 3.2164  
Epoch: [1][3360/40998] Data 0.553 (0.553) Elapsed 47m 49s (remain 535m 37s) Loss: 0.2331(0.3849) Grad: 1.6425  
Epoch: [1][3380/40998] Data 0.553 (0.553) Elapsed 48m 6s (remain 535m 20s) Loss: 0.3989(0.3844) Grad: 3.6748  
Epoch: [1][3400/40998] Data 0.553 (0.553) Elapsed 48m 24s (remain 535m 3s) Loss: 0.3708(0.3839) Grad: 6.4415  
Epoch: [1][3420/40998] Data 0.553 (0.553) Elapsed 48m 41s (remain 534m 45s) Loss: 0.3647(0.3835) Grad: 3.6918  
Epoch: [1][3440/40998] Data 0.552 (0.553) Elapsed 48m 58s (remain 534m 28s) Loss: 0.3599(0.3833) Grad: 4.7906  
Epoch: [1][3460/40998] Data 0.552 (0.553) Elapsed 49m 15s (remain 534m 11s) Loss: 0.4861(0.3831) Grad: 4.1276  
Epoch: [1][3480/40998] Data 0.553 (0.553) Elapsed 49m 32s (remain 533m 54s) Loss: 0.2374(0.3827) Grad: 2.7598  
Epoch: [1][3500/40998] Data 0.553 (0.553) Elapsed 49m 49s (remain 533m 36s) Loss: 0.1992(0.3821) Grad: 3.7130  
Epoch: [1][3520/40998] Data 0.552 (0.553) Elapsed 50m 6s (remain 533m 19s) Loss: 0.3234(0.3818) Grad: 2.9725  
Epoch: [1][3540/40998] Data 0.553 (0.553) Elapsed 50m 23s (remain 533m 2s) Loss: 0.2880(0.3813) Grad: 3.5453  
Epoch: [1][3560/40998] Data 0.552 (0.553) Elapsed 50m 40s (remain 532m 45s) Loss: 0.2963(0.3808) Grad: 2.6271  
Epoch: [1][3580/40998] Data 0.554 (0.553) Elapsed 50m 57s (remain 532m 27s) Loss: 0.1611(0.3803) Grad: 2.6611  
Epoch: [1][3600/40998] Data 0.552 (0.553) Elapsed 51m 14s (remain 532m 10s) Loss: 0.3852(0.3798) Grad: 4.4395  
Epoch: [1][3620/40998] Data 0.552 (0.553) Elapsed 51m 31s (remain 531m 53s) Loss: 0.3651(0.3790) Grad: 3.9010  
Epoch: [1][3640/40998] Data 0.552 (0.553) Elapsed 51m 48s (remain 531m 36s) Loss: 0.2293(0.3784) Grad: 2.5971  
Epoch: [1][3660/40998] Data 0.553 (0.553) Elapsed 52m 5s (remain 531m 19s) Loss: 0.2678(0.3779) Grad: 2.4506  
Epoch: [1][3680/40998] Data 0.553 (0.553) Elapsed 52m 22s (remain 531m 1s) Loss: 0.1551(0.3773) Grad: 2.0285  
Epoch: [1][3700/40998] Data 0.552 (0.553) Elapsed 52m 39s (remain 530m 44s) Loss: 0.2663(0.3767) Grad: 3.4638  
Epoch: [1][3720/40998] Data 0.552 (0.553) Elapsed 52m 57s (remain 530m 27s) Loss: 0.2276(0.3761) Grad: 3.0073  
Epoch: [1][3740/40998] Data 0.552 (0.553) Elapsed 53m 14s (remain 530m 10s) Loss: 0.2513(0.3755) Grad: 3.2773  
Epoch: [1][3760/40998] Data 0.552 (0.553) Elapsed 53m 31s (remain 529m 52s) Loss: 0.2192(0.3753) Grad: 1.9776  
Epoch: [1][3780/40998] Data 0.550 (0.553) Elapsed 53m 48s (remain 529m 35s) Loss: 0.2868(0.3750) Grad: 3.6985  
Epoch: [1][3800/40998] Data 0.552 (0.553) Elapsed 54m 5s (remain 529m 18s) Loss: 0.1692(0.3746) Grad: 2.2160  
Epoch: [1][3820/40998] Data 0.553 (0.553) Elapsed 54m 22s (remain 529m 1s) Loss: 0.1742(0.3741) Grad: 1.9484  
Epoch: [1][3840/40998] Data 0.553 (0.553) Elapsed 54m 39s (remain 528m 44s) Loss: 0.3369(0.3737) Grad: 3.7118  
Epoch: [1][3860/40998] Data 0.553 (0.553) Elapsed 54m 56s (remain 528m 26s) Loss: 0.3194(0.3733) Grad: 4.3723  
Epoch: [1][3880/40998] Data 0.553 (0.553) Elapsed 55m 13s (remain 528m 9s) Loss: 0.2712(0.3727) Grad: 3.4874  
Epoch: [1][3900/40998] Data 0.553 (0.553) Elapsed 55m 30s (remain 527m 52s) Loss: 0.2289(0.3724) Grad: 2.4119  
Epoch: [1][3920/40998] Data 0.553 (0.553) Elapsed 55m 47s (remain 527m 35s) Loss: 0.2685(0.3719) Grad: 2.0318  
Epoch: [1][3940/40998] Data 0.553 (0.553) Elapsed 56m 4s (remain 527m 18s) Loss: 0.1566(0.3714) Grad: 2.2899  
Epoch: [1][3960/40998] Data 0.552 (0.553) Elapsed 56m 21s (remain 527m 0s) Loss: 0.2567(0.3708) Grad: 3.0348  
Epoch: [1][3980/40998] Data 0.552 (0.553) Elapsed 56m 38s (remain 526m 43s) Loss: 0.5365(0.3707) Grad: 3.9728  
Epoch: [1][4000/40998] Data 0.553 (0.553) Elapsed 56m 55s (remain 526m 26s) Loss: 0.2913(0.3704) Grad: 2.5863  
Epoch: [1][4020/40998] Data 0.553 (0.553) Elapsed 57m 12s (remain 526m 9s) Loss: 0.3015(0.3701) Grad: 3.8422  
Epoch: [1][4040/40998] Data 0.553 (0.553) Elapsed 57m 30s (remain 525m 52s) Loss: 0.2245(0.3698) Grad: 3.5343  
Epoch: [1][4060/40998] Data 0.552 (0.553) Elapsed 57m 47s (remain 525m 34s) Loss: 0.4113(0.3693) Grad: 3.2621  
Epoch: [1][4080/40998] Data 0.553 (0.553) Elapsed 58m 4s (remain 525m 17s) Loss: 0.1720(0.3691) Grad: 2.8943  
Epoch: [1][4100/40998] Data 0.553 (0.553) Elapsed 58m 21s (remain 525m 0s) Loss: 0.3074(0.3688) Grad: 3.4855  
Epoch: [1][4120/40998] Data 0.552 (0.553) Elapsed 58m 38s (remain 524m 43s) Loss: 0.4183(0.3683) Grad: 3.4040  
Epoch: [1][4140/40998] Data 0.553 (0.553) Elapsed 58m 55s (remain 524m 26s) Loss: 0.6403(0.3680) Grad: 7.0657  
Epoch: [1][4160/40998] Data 0.553 (0.553) Elapsed 59m 12s (remain 524m 8s) Loss: 0.1802(0.3676) Grad: 2.4993  
Epoch: [1][4180/40998] Data 0.553 (0.553) Elapsed 59m 29s (remain 523m 51s) Loss: 0.5426(0.3673) Grad: 5.0276  
Epoch: [1][4200/40998] Data 0.553 (0.553) Elapsed 59m 46s (remain 523m 34s) Loss: 0.2038(0.3667) Grad: 2.8629  
Epoch: [1][4220/40998] Data 0.553 (0.553) Elapsed 60m 3s (remain 523m 17s) Loss: 0.2117(0.3663) Grad: 2.2533  
Epoch: [1][4240/40998] Data 0.553 (0.553) Elapsed 60m 20s (remain 523m 0s) Loss: 0.3017(0.3660) Grad: 4.0852  
Epoch: [1][4260/40998] Data 0.553 (0.553) Elapsed 60m 37s (remain 522m 42s) Loss: 0.1612(0.3655) Grad: 2.9645  
Epoch: [1][4280/40998] Data 0.553 (0.553) Elapsed 60m 54s (remain 522m 25s) Loss: 0.8159(0.3652) Grad: 12.0382  
Epoch: [1][4300/40998] Data 0.553 (0.553) Elapsed 61m 11s (remain 522m 8s) Loss: 0.2168(0.3646) Grad: 2.7309  
Epoch: [1][4320/40998] Data 0.553 (0.553) Elapsed 61m 28s (remain 521m 51s) Loss: 0.1089(0.3644) Grad: 1.3551  
Epoch: [1][4340/40998] Data 0.552 (0.553) Elapsed 61m 45s (remain 521m 34s) Loss: 0.2641(0.3640) Grad: 2.3534  
Epoch: [1][4360/40998] Data 0.552 (0.553) Elapsed 62m 2s (remain 521m 17s) Loss: 0.0522(0.3634) Grad: 0.7182  
Epoch: [1][4380/40998] Data 0.552 (0.553) Elapsed 62m 20s (remain 520m 59s) Loss: 0.4386(0.3631) Grad: 3.6599  
Epoch: [1][4400/40998] Data 0.553 (0.553) Elapsed 62m 37s (remain 520m 42s) Loss: 0.3075(0.3630) Grad: 3.1810  
Epoch: [1][4420/40998] Data 0.552 (0.553) Elapsed 62m 54s (remain 520m 25s) Loss: 0.1202(0.3624) Grad: 1.5744  
Epoch: [1][4440/40998] Data 0.553 (0.553) Elapsed 63m 11s (remain 520m 8s) Loss: 0.1457(0.3618) Grad: 2.7289  
Epoch: [1][4460/40998] Data 0.553 (0.553) Elapsed 63m 28s (remain 519m 51s) Loss: 0.1105(0.3614) Grad: 2.1635  
Epoch: [1][4480/40998] Data 0.553 (0.553) Elapsed 63m 45s (remain 519m 33s) Loss: 0.5812(0.3614) Grad: 6.0290  
Epoch: [1][4500/40998] Data 0.553 (0.553) Elapsed 64m 2s (remain 519m 16s) Loss: 0.1963(0.3610) Grad: 2.1509  
Epoch: [1][4520/40998] Data 0.553 (0.553) Elapsed 64m 19s (remain 518m 59s) Loss: 0.1424(0.3606) Grad: 1.7593  
Epoch: [1][4540/40998] Data 0.553 (0.553) Elapsed 64m 36s (remain 518m 42s) Loss: 0.2137(0.3603) Grad: 4.0668  
Epoch: [1][4560/40998] Data 0.553 (0.553) Elapsed 64m 53s (remain 518m 25s) Loss: 0.4333(0.3600) Grad: 3.4192  
Epoch: [1][4580/40998] Data 0.553 (0.553) Elapsed 65m 10s (remain 518m 8s) Loss: 0.2566(0.3598) Grad: 3.0899  
Epoch: [1][4600/40998] Data 0.553 (0.553) Elapsed 65m 27s (remain 517m 50s) Loss: 0.2513(0.3593) Grad: 3.2703  
Epoch: [1][4620/40998] Data 0.553 (0.553) Elapsed 65m 44s (remain 517m 33s) Loss: 0.0455(0.3589) Grad: 0.7746  
Epoch: [1][4640/40998] Data 0.552 (0.553) Elapsed 66m 1s (remain 517m 16s) Loss: 0.1582(0.3585) Grad: 2.8523  
Epoch: [1][4660/40998] Data 0.552 (0.553) Elapsed 66m 18s (remain 516m 59s) Loss: 0.4539(0.3580) Grad: 3.6360  
Epoch: [1][4680/40998] Data 0.552 (0.553) Elapsed 66m 35s (remain 516m 42s) Loss: 0.4088(0.3579) Grad: 3.6551  
Epoch: [1][4700/40998] Data 0.553 (0.553) Elapsed 66m 53s (remain 516m 24s) Loss: 0.2001(0.3578) Grad: 1.6809  
Epoch: [1][4720/40998] Data 0.553 (0.553) Elapsed 67m 10s (remain 516m 7s) Loss: 0.1451(0.3574) Grad: 1.8845  
Epoch: [1][4740/40998] Data 0.552 (0.553) Elapsed 67m 27s (remain 515m 50s) Loss: 0.1173(0.3571) Grad: 1.7397  
Epoch: [1][4760/40998] Data 0.553 (0.553) Elapsed 67m 44s (remain 515m 33s) Loss: 0.1500(0.3568) Grad: 1.9881  
Epoch: [1][4780/40998] Data 0.553 (0.553) Elapsed 68m 1s (remain 515m 16s) Loss: 0.1083(0.3563) Grad: 1.5966  
Epoch: [1][4800/40998] Data 0.553 (0.553) Elapsed 68m 18s (remain 514m 59s) Loss: 1.3131(0.3562) Grad: 12.6580  
Epoch: [1][4820/40998] Data 0.553 (0.553) Elapsed 68m 35s (remain 514m 42s) Loss: 0.1663(0.3557) Grad: 1.6973  
Epoch: [1][4840/40998] Data 0.553 (0.553) Elapsed 68m 52s (remain 514m 24s) Loss: 0.1008(0.3552) Grad: 1.4107  
Epoch: [1][4860/40998] Data 0.553 (0.553) Elapsed 69m 9s (remain 514m 7s) Loss: 0.6115(0.3548) Grad: 5.3901  
Epoch: [1][4880/40998] Data 0.553 (0.553) Elapsed 69m 26s (remain 513m 50s) Loss: 0.2030(0.3543) Grad: 2.4576  
Epoch: [1][4900/40998] Data 0.553 (0.553) Elapsed 69m 43s (remain 513m 33s) Loss: 0.1836(0.3539) Grad: 2.8831  
Epoch: [1][4920/40998] Data 0.553 (0.553) Elapsed 70m 0s (remain 513m 16s) Loss: 0.3837(0.3535) Grad: 3.9599  
Epoch: [1][4940/40998] Data 0.553 (0.553) Elapsed 70m 17s (remain 512m 59s) Loss: 0.3085(0.3532) Grad: 5.1727  
Epoch: [1][4960/40998] Data 0.553 (0.553) Elapsed 70m 34s (remain 512m 41s) Loss: 0.4196(0.3530) Grad: 4.3328  
Epoch: [1][4980/40998] Data 0.553 (0.553) Elapsed 70m 51s (remain 512m 24s) Loss: 0.1627(0.3526) Grad: 1.9059  
Epoch: [1][5000/40998] Data 0.553 (0.553) Elapsed 71m 8s (remain 512m 7s) Loss: 0.1264(0.3524) Grad: 1.5472  
Epoch: [1][5020/40998] Data 0.553 (0.553) Elapsed 71m 26s (remain 511m 50s) Loss: 0.2456(0.3521) Grad: 4.0302  
Epoch: [1][5040/40998] Data 0.552 (0.553) Elapsed 71m 43s (remain 511m 33s) Loss: 0.2323(0.3518) Grad: 2.7378  
Epoch: [1][5060/40998] Data 0.552 (0.553) Elapsed 72m 0s (remain 511m 16s) Loss: 0.1625(0.3513) Grad: 3.9250  
Epoch: [1][5080/40998] Data 0.552 (0.553) Elapsed 72m 17s (remain 510m 59s) Loss: 0.1275(0.3509) Grad: 1.5964  
Epoch: [1][5100/40998] Data 0.553 (0.553) Elapsed 72m 34s (remain 510m 41s) Loss: 0.1923(0.3506) Grad: 2.5608  
Epoch: [1][5120/40998] Data 0.553 (0.553) Elapsed 72m 51s (remain 510m 24s) Loss: 0.3269(0.3503) Grad: 3.5254  
Epoch: [1][5140/40998] Data 0.553 (0.553) Elapsed 73m 8s (remain 510m 7s) Loss: 0.2733(0.3500) Grad: 2.8224  
Epoch: [1][5160/40998] Data 0.551 (0.553) Elapsed 73m 25s (remain 509m 50s) Loss: 0.1683(0.3497) Grad: 1.6668  
Epoch: [1][5180/40998] Data 0.553 (0.553) Elapsed 73m 42s (remain 509m 33s) Loss: 0.3279(0.3493) Grad: 2.4237  
Epoch: [1][5200/40998] Data 0.553 (0.553) Elapsed 73m 59s (remain 509m 16s) Loss: 0.1168(0.3491) Grad: 1.4203  
Epoch: [1][5220/40998] Data 0.553 (0.553) Elapsed 74m 16s (remain 508m 58s) Loss: 0.2589(0.3488) Grad: 3.2881  
Epoch: [1][5240/40998] Data 0.553 (0.553) Elapsed 74m 33s (remain 508m 41s) Loss: 0.1274(0.3483) Grad: 2.2733  
Epoch: [1][5260/40998] Data 0.553 (0.553) Elapsed 74m 50s (remain 508m 24s) Loss: 0.2523(0.3481) Grad: 3.0727  
Epoch: [1][5280/40998] Data 0.553 (0.553) Elapsed 75m 7s (remain 508m 7s) Loss: 0.3305(0.3479) Grad: 2.9528  
Epoch: [1][5300/40998] Data 0.552 (0.553) Elapsed 75m 24s (remain 507m 50s) Loss: 0.2258(0.3474) Grad: 2.5228  
Epoch: [1][5320/40998] Data 0.553 (0.553) Elapsed 75m 41s (remain 507m 33s) Loss: 0.3892(0.3472) Grad: 3.8145  
Epoch: [1][5340/40998] Data 0.552 (0.553) Elapsed 75m 58s (remain 507m 16s) Loss: 0.3844(0.3468) Grad: 5.6329  
Epoch: [1][5360/40998] Data 0.553 (0.553) Elapsed 76m 16s (remain 506m 58s) Loss: 0.7671(0.3466) Grad: 9.5603  
Epoch: [1][5380/40998] Data 0.552 (0.553) Elapsed 76m 33s (remain 506m 41s) Loss: 0.3167(0.3464) Grad: 5.0064  
Epoch: [1][5400/40998] Data 0.553 (0.553) Elapsed 76m 50s (remain 506m 24s) Loss: 0.2259(0.3461) Grad: 3.2661  
Epoch: [1][5420/40998] Data 0.553 (0.553) Elapsed 77m 7s (remain 506m 7s) Loss: 0.4502(0.3457) Grad: 3.3215  
Epoch: [1][5440/40998] Data 0.553 (0.553) Elapsed 77m 24s (remain 505m 50s) Loss: 0.2984(0.3454) Grad: 2.8672  
Epoch: [1][5460/40998] Data 0.553 (0.553) Elapsed 77m 41s (remain 505m 33s) Loss: 0.3277(0.3454) Grad: 2.9040  
Epoch: [1][5480/40998] Data 0.553 (0.553) Elapsed 77m 58s (remain 505m 16s) Loss: 0.1802(0.3451) Grad: 1.5881  
Epoch: [1][5500/40998] Data 0.553 (0.553) Elapsed 78m 15s (remain 504m 59s) Loss: 0.6011(0.3450) Grad: 5.3194  
Epoch: [1][5520/40998] Data 0.553 (0.553) Elapsed 78m 32s (remain 504m 41s) Loss: 0.0747(0.3447) Grad: 1.0549  
Epoch: [1][5540/40998] Data 0.553 (0.553) Elapsed 78m 49s (remain 504m 24s) Loss: 0.1258(0.3444) Grad: 2.4118  
Epoch: [1][5560/40998] Data 0.553 (0.553) Elapsed 79m 6s (remain 504m 7s) Loss: 0.2437(0.3441) Grad: 2.6214  
Epoch: [1][5580/40998] Data 0.553 (0.553) Elapsed 79m 23s (remain 503m 50s) Loss: 0.3223(0.3438) Grad: 2.8051  
Epoch: [1][5600/40998] Data 0.553 (0.553) Elapsed 79m 40s (remain 503m 33s) Loss: 0.2120(0.3434) Grad: 4.0831  
Epoch: [1][5620/40998] Data 0.551 (0.553) Elapsed 79m 57s (remain 503m 16s) Loss: 0.2394(0.3430) Grad: 3.0675  
Epoch: [1][5640/40998] Data 0.552 (0.553) Elapsed 80m 14s (remain 502m 59s) Loss: 0.0989(0.3429) Grad: 0.9570  
Epoch: [1][5660/40998] Data 0.553 (0.553) Elapsed 80m 31s (remain 502m 41s) Loss: 0.1940(0.3425) Grad: 3.8546  
Epoch: [1][5680/40998] Data 0.553 (0.553) Elapsed 80m 49s (remain 502m 24s) Loss: 0.1959(0.3422) Grad: 3.8876  
Epoch: [1][5700/40998] Data 0.553 (0.553) Elapsed 81m 6s (remain 502m 7s) Loss: 0.4826(0.3421) Grad: 5.0981  
Epoch: [1][5720/40998] Data 0.553 (0.553) Elapsed 81m 23s (remain 501m 50s) Loss: 0.2601(0.3420) Grad: 2.8846  
Epoch: [1][5740/40998] Data 0.553 (0.553) Elapsed 81m 40s (remain 501m 33s) Loss: 0.3377(0.3417) Grad: 3.2157  
Epoch: [1][5760/40998] Data 0.553 (0.553) Elapsed 81m 57s (remain 501m 16s) Loss: 0.1182(0.3413) Grad: 1.0749  
Epoch: [1][5780/40998] Data 0.553 (0.553) Elapsed 82m 14s (remain 500m 59s) Loss: 0.2757(0.3411) Grad: 4.0697  
Epoch: [1][5800/40998] Data 0.553 (0.553) Elapsed 82m 31s (remain 500m 42s) Loss: 0.2644(0.3408) Grad: 3.9125  
Epoch: [1][5820/40998] Data 0.553 (0.553) Elapsed 82m 48s (remain 500m 24s) Loss: 0.1941(0.3405) Grad: 2.2366  
Epoch: [1][5840/40998] Data 0.553 (0.553) Elapsed 83m 5s (remain 500m 7s) Loss: 0.1492(0.3403) Grad: 2.0205  
Epoch: [1][5860/40998] Data 0.553 (0.553) Elapsed 83m 22s (remain 499m 50s) Loss: 0.1622(0.3401) Grad: 2.2099  
Epoch: [1][5880/40998] Data 0.552 (0.553) Elapsed 83m 39s (remain 499m 33s) Loss: 0.1354(0.3399) Grad: 1.9441  
Epoch: [1][5900/40998] Data 0.553 (0.553) Elapsed 83m 56s (remain 499m 16s) Loss: 0.2453(0.3396) Grad: 3.3732  
Epoch: [1][5920/40998] Data 0.552 (0.553) Elapsed 84m 13s (remain 498m 59s) Loss: 0.4683(0.3394) Grad: 3.3761  
Epoch: [1][5940/40998] Data 0.553 (0.553) Elapsed 84m 30s (remain 498m 42s) Loss: 0.2050(0.3392) Grad: 2.1987  
Epoch: [1][5960/40998] Data 0.553 (0.553) Elapsed 84m 47s (remain 498m 24s) Loss: 0.1026(0.3389) Grad: 1.4665  
Epoch: [1][5980/40998] Data 0.553 (0.553) Elapsed 85m 4s (remain 498m 7s) Loss: 0.5104(0.3385) Grad: 4.8246  
Epoch: [1][6000/40998] Data 0.553 (0.553) Elapsed 85m 21s (remain 497m 50s) Loss: 0.6427(0.3383) Grad: 5.2195  
Epoch: [1][6020/40998] Data 0.553 (0.553) Elapsed 85m 39s (remain 497m 33s) Loss: 0.3244(0.3380) Grad: 2.8280  
Epoch: [1][6040/40998] Data 0.553 (0.553) Elapsed 85m 56s (remain 497m 16s) Loss: 0.2403(0.3376) Grad: 3.4935  
Epoch: [1][6060/40998] Data 0.552 (0.553) Elapsed 86m 13s (remain 496m 59s) Loss: 0.1683(0.3372) Grad: 2.1730  
Epoch: [1][6080/40998] Data 0.553 (0.553) Elapsed 86m 30s (remain 496m 42s) Loss: 0.1496(0.3369) Grad: 2.8679  
Epoch: [1][6100/40998] Data 0.553 (0.553) Elapsed 86m 47s (remain 496m 25s) Loss: 0.1647(0.3366) Grad: 2.5862  
Epoch: [1][6120/40998] Data 0.553 (0.553) Elapsed 87m 4s (remain 496m 7s) Loss: 0.2002(0.3365) Grad: 1.6921  
Epoch: [1][6140/40998] Data 0.553 (0.553) Elapsed 87m 21s (remain 495m 50s) Loss: 0.1768(0.3361) Grad: 2.3348  
Epoch: [1][6160/40998] Data 0.554 (0.553) Elapsed 87m 38s (remain 495m 33s) Loss: 0.1564(0.3358) Grad: 2.0361  
Epoch: [1][6180/40998] Data 0.553 (0.553) Elapsed 87m 55s (remain 495m 16s) Loss: 0.1655(0.3355) Grad: 3.4164  
Epoch: [1][6200/40998] Data 0.552 (0.553) Elapsed 88m 12s (remain 494m 59s) Loss: 0.1463(0.3353) Grad: 2.1613  
Epoch: [1][6220/40998] Data 0.553 (0.553) Elapsed 88m 29s (remain 494m 42s) Loss: 0.1725(0.3351) Grad: 2.3651  
Epoch: [1][6240/40998] Data 0.553 (0.553) Elapsed 88m 46s (remain 494m 25s) Loss: 0.1361(0.3348) Grad: 2.0024  
Epoch: [1][6260/40998] Data 0.553 (0.553) Elapsed 89m 3s (remain 494m 8s) Loss: 0.4062(0.3346) Grad: 4.4391  
Epoch: [1][6280/40998] Data 0.552 (0.553) Elapsed 89m 20s (remain 493m 51s) Loss: 0.0795(0.3343) Grad: 1.3309  
Epoch: [1][6300/40998] Data 0.552 (0.553) Elapsed 89m 37s (remain 493m 33s) Loss: 0.2992(0.3340) Grad: 4.3284  
Epoch: [1][6320/40998] Data 0.553 (0.553) Elapsed 89m 54s (remain 493m 16s) Loss: 0.5671(0.3336) Grad: 4.3166  
Epoch: [1][6340/40998] Data 0.553 (0.553) Elapsed 90m 12s (remain 492m 59s) Loss: 0.1979(0.3333) Grad: 2.2547  
Epoch: [1][6360/40998] Data 0.552 (0.553) Elapsed 90m 29s (remain 492m 42s) Loss: 0.5255(0.3332) Grad: 4.7358  
Epoch: [1][6380/40998] Data 0.553 (0.553) Elapsed 90m 46s (remain 492m 25s) Loss: 0.2170(0.3329) Grad: 3.2626  
Epoch: [1][6400/40998] Data 0.552 (0.553) Elapsed 91m 3s (remain 492m 8s) Loss: 0.1909(0.3326) Grad: 2.6825  
Epoch: [1][6420/40998] Data 0.553 (0.553) Elapsed 91m 20s (remain 491m 51s) Loss: 0.2815(0.3323) Grad: 2.9678  
Epoch: [1][6440/40998] Data 0.553 (0.553) Elapsed 91m 37s (remain 491m 34s) Loss: 0.4191(0.3321) Grad: 3.4090  
Epoch: [1][6460/40998] Data 0.553 (0.553) Elapsed 91m 54s (remain 491m 17s) Loss: 0.1470(0.3319) Grad: 2.0783  
Epoch: [1][6480/40998] Data 0.553 (0.553) Elapsed 92m 11s (remain 490m 59s) Loss: 0.2028(0.3316) Grad: 2.5423  
Epoch: [1][6500/40998] Data 0.553 (0.553) Elapsed 92m 28s (remain 490m 42s) Loss: 0.4085(0.3314) Grad: 3.5672  
Epoch: [1][6520/40998] Data 0.553 (0.553) Elapsed 92m 45s (remain 490m 25s) Loss: 0.3225(0.3311) Grad: 2.9536  
Epoch: [1][6540/40998] Data 0.553 (0.553) Elapsed 93m 2s (remain 490m 8s) Loss: 0.1725(0.3308) Grad: 2.4769  
Epoch: [1][6560/40998] Data 0.552 (0.553) Elapsed 93m 19s (remain 489m 51s) Loss: 0.2230(0.3307) Grad: 3.6438  
Epoch: [1][6580/40998] Data 0.553 (0.553) Elapsed 93m 36s (remain 489m 34s) Loss: 0.4074(0.3306) Grad: 3.9932  
Epoch: [1][6600/40998] Data 0.552 (0.553) Elapsed 93m 53s (remain 489m 17s) Loss: 0.3403(0.3304) Grad: 2.7387  
Epoch: [1][6620/40998] Data 0.552 (0.553) Elapsed 94m 10s (remain 489m 0s) Loss: 0.1046(0.3301) Grad: 5.3190  
Epoch: [1][6640/40998] Data 0.552 (0.553) Elapsed 94m 27s (remain 488m 43s) Loss: 0.1403(0.3298) Grad: 2.6859  
Epoch: [1][6660/40998] Data 0.552 (0.553) Elapsed 94m 45s (remain 488m 25s) Loss: 0.2339(0.3295) Grad: 4.2010  
Epoch: [1][6680/40998] Data 0.553 (0.553) Elapsed 95m 2s (remain 488m 8s) Loss: 0.2946(0.3292) Grad: 3.1304  
Epoch: [1][6700/40998] Data 0.553 (0.553) Elapsed 95m 19s (remain 487m 51s) Loss: 0.4629(0.3289) Grad: 3.6245  
Epoch: [1][6720/40998] Data 0.552 (0.553) Elapsed 95m 36s (remain 487m 34s) Loss: 0.2700(0.3287) Grad: 2.6775  
Epoch: [1][6740/40998] Data 0.552 (0.553) Elapsed 95m 53s (remain 487m 17s) Loss: 0.4529(0.3283) Grad: 7.8528  
Epoch: [1][6760/40998] Data 0.553 (0.553) Elapsed 96m 10s (remain 487m 0s) Loss: 0.2479(0.3279) Grad: 2.9501  
Epoch: [1][6780/40998] Data 0.552 (0.553) Elapsed 96m 27s (remain 486m 43s) Loss: 0.3409(0.3277) Grad: 3.2633  
Epoch: [1][6800/40998] Data 0.553 (0.553) Elapsed 96m 44s (remain 486m 26s) Loss: 0.1826(0.3273) Grad: 2.0594  
Epoch: [1][6820/40998] Data 0.552 (0.553) Elapsed 97m 1s (remain 486m 9s) Loss: 0.3317(0.3269) Grad: 3.5999  
Epoch: [1][6840/40998] Data 0.553 (0.553) Elapsed 97m 18s (remain 485m 52s) Loss: 0.5959(0.3267) Grad: 7.9406  
Epoch: [1][6860/40998] Data 0.553 (0.553) Elapsed 97m 35s (remain 485m 34s) Loss: 0.1806(0.3265) Grad: 1.9975  
Epoch: [1][6880/40998] Data 0.552 (0.553) Elapsed 97m 52s (remain 485m 17s) Loss: 0.3642(0.3263) Grad: 2.6956  
Epoch: [1][6900/40998] Data 0.553 (0.553) Elapsed 98m 9s (remain 485m 0s) Loss: 0.3441(0.3261) Grad: 3.6517  
Epoch: [1][6920/40998] Data 0.552 (0.553) Elapsed 98m 26s (remain 484m 43s) Loss: 0.4589(0.3257) Grad: 4.8953  
Epoch: [1][6940/40998] Data 0.553 (0.553) Elapsed 98m 43s (remain 484m 26s) Loss: 0.4253(0.3256) Grad: 3.7347  
Epoch: [1][6960/40998] Data 0.553 (0.553) Elapsed 99m 0s (remain 484m 9s) Loss: 0.1996(0.3254) Grad: 1.4495  
Epoch: [1][6980/40998] Data 0.553 (0.553) Elapsed 99m 18s (remain 483m 52s) Loss: 0.1665(0.3251) Grad: 2.1453  
Epoch: [1][7000/40998] Data 0.553 (0.553) Elapsed 99m 35s (remain 483m 35s) Loss: 0.2467(0.3247) Grad: 3.0714  
Epoch: [1][7020/40998] Data 0.553 (0.553) Elapsed 99m 52s (remain 483m 18s) Loss: 0.1989(0.3244) Grad: 1.5400  
Epoch: [1][7040/40998] Data 0.553 (0.553) Elapsed 100m 9s (remain 483m 1s) Loss: 0.2887(0.3242) Grad: 4.0853  
Epoch: [1][7060/40998] Data 0.553 (0.553) Elapsed 100m 26s (remain 482m 44s) Loss: 0.5693(0.3239) Grad: 6.2781  
Epoch: [1][7080/40998] Data 0.552 (0.553) Elapsed 100m 43s (remain 482m 26s) Loss: 0.2596(0.3237) Grad: 4.4527  
Epoch: [1][7100/40998] Data 0.553 (0.553) Elapsed 101m 0s (remain 482m 9s) Loss: 0.0502(0.3233) Grad: 1.4079  
Epoch: [1][7120/40998] Data 0.552 (0.553) Elapsed 101m 17s (remain 481m 52s) Loss: 0.3165(0.3231) Grad: 2.9040  
Epoch: [1][7140/40998] Data 0.553 (0.553) Elapsed 101m 34s (remain 481m 35s) Loss: 0.2737(0.3229) Grad: 3.4684  
Epoch: [1][7160/40998] Data 0.553 (0.553) Elapsed 101m 51s (remain 481m 18s) Loss: 0.2868(0.3228) Grad: 2.9093  
Epoch: [1][7180/40998] Data 0.553 (0.553) Elapsed 102m 8s (remain 481m 1s) Loss: 0.2075(0.3225) Grad: 2.4910  
Epoch: [1][7200/40998] Data 0.553 (0.553) Elapsed 102m 25s (remain 480m 44s) Loss: 0.0768(0.3221) Grad: 1.4086  
Epoch: [1][7220/40998] Data 0.553 (0.553) Elapsed 102m 42s (remain 480m 27s) Loss: 0.1626(0.3219) Grad: 2.2545  
Epoch: [1][7240/40998] Data 0.553 (0.553) Elapsed 102m 59s (remain 480m 10s) Loss: 0.3435(0.3217) Grad: 2.3138  
Epoch: [1][7260/40998] Data 0.553 (0.553) Elapsed 103m 16s (remain 479m 53s) Loss: 0.1875(0.3214) Grad: 2.7984  
Epoch: [1][7280/40998] Data 0.552 (0.553) Elapsed 103m 34s (remain 479m 36s) Loss: 0.0248(0.3213) Grad: 0.5163  
Epoch: [1][7300/40998] Data 0.552 (0.553) Elapsed 103m 51s (remain 479m 18s) Loss: 0.2297(0.3213) Grad: 2.2662  
Epoch: [1][7320/40998] Data 0.553 (0.553) Elapsed 104m 8s (remain 479m 1s) Loss: 0.4000(0.3212) Grad: 3.7396  
Epoch: [1][7340/40998] Data 0.553 (0.553) Elapsed 104m 25s (remain 478m 44s) Loss: 0.3687(0.3209) Grad: 3.2830  
Epoch: [1][7360/40998] Data 0.553 (0.553) Elapsed 104m 42s (remain 478m 27s) Loss: 0.1503(0.3207) Grad: 1.4308  
Epoch: [1][7380/40998] Data 0.553 (0.553) Elapsed 104m 59s (remain 478m 10s) Loss: 0.1425(0.3207) Grad: 1.3802  
Epoch: [1][7400/40998] Data 0.553 (0.553) Elapsed 105m 16s (remain 477m 53s) Loss: 0.3844(0.3205) Grad: 3.5767  
Epoch: [1][7420/40998] Data 0.552 (0.553) Elapsed 105m 33s (remain 477m 36s) Loss: 0.0928(0.3203) Grad: 0.9478  
Epoch: [1][7440/40998] Data 0.553 (0.553) Elapsed 105m 50s (remain 477m 19s) Loss: 0.3098(0.3200) Grad: 3.4717  
Epoch: [1][7460/40998] Data 0.554 (0.553) Elapsed 106m 7s (remain 477m 2s) Loss: 0.3247(0.3198) Grad: 3.5050  
Epoch: [1][7480/40998] Data 0.552 (0.553) Elapsed 106m 24s (remain 476m 45s) Loss: 0.1858(0.3195) Grad: 1.9472  
Epoch: [1][7500/40998] Data 0.553 (0.553) Elapsed 106m 41s (remain 476m 28s) Loss: 0.3041(0.3193) Grad: 3.9544  
Epoch: [1][7520/40998] Data 0.553 (0.553) Elapsed 106m 58s (remain 476m 10s) Loss: 0.1156(0.3191) Grad: 1.3310  
Epoch: [1][7540/40998] Data 0.553 (0.553) Elapsed 107m 15s (remain 475m 53s) Loss: 0.1100(0.3189) Grad: 2.3017  
Epoch: [1][7560/40998] Data 0.552 (0.553) Elapsed 107m 32s (remain 475m 36s) Loss: 0.1350(0.3187) Grad: 2.3109  
Epoch: [1][7580/40998] Data 0.553 (0.553) Elapsed 107m 49s (remain 475m 19s) Loss: 0.1409(0.3184) Grad: 2.2654  
Epoch: [1][7600/40998] Data 0.552 (0.553) Elapsed 108m 7s (remain 475m 2s) Loss: 0.3467(0.3181) Grad: 5.2347  
Epoch: [1][7620/40998] Data 0.553 (0.553) Elapsed 108m 24s (remain 474m 45s) Loss: 0.2367(0.3179) Grad: 3.5917  
Epoch: [1][7640/40998] Data 0.552 (0.553) Elapsed 108m 41s (remain 474m 28s) Loss: 0.1584(0.3176) Grad: 1.9672  
Epoch: [1][7660/40998] Data 0.553 (0.553) Elapsed 108m 58s (remain 474m 11s) Loss: 0.1541(0.3174) Grad: 1.5600  
Epoch: [1][7680/40998] Data 0.553 (0.553) Elapsed 109m 15s (remain 473m 54s) Loss: 0.1570(0.3171) Grad: 2.4373  
Epoch: [1][7700/40998] Data 0.553 (0.553) Elapsed 109m 32s (remain 473m 37s) Loss: 0.4842(0.3172) Grad: 4.4670  
Epoch: [1][7720/40998] Data 0.552 (0.553) Elapsed 109m 49s (remain 473m 20s) Loss: 0.1896(0.3169) Grad: 4.2196  
Epoch: [1][7740/40998] Data 0.553 (0.553) Elapsed 110m 6s (remain 473m 2s) Loss: 0.1877(0.3167) Grad: 2.8570  
Epoch: [1][7760/40998] Data 0.553 (0.553) Elapsed 110m 23s (remain 472m 45s) Loss: 0.2565(0.3164) Grad: 2.3546  
Epoch: [1][7780/40998] Data 0.553 (0.553) Elapsed 110m 40s (remain 472m 28s) Loss: 0.1454(0.3162) Grad: 2.0119  
Epoch: [1][7800/40998] Data 0.553 (0.553) Elapsed 110m 57s (remain 472m 11s) Loss: 0.1571(0.3160) Grad: 1.7658  
Epoch: [1][7820/40998] Data 0.553 (0.553) Elapsed 111m 14s (remain 471m 54s) Loss: 0.4768(0.3158) Grad: 6.5107  
Epoch: [1][7840/40998] Data 0.552 (0.553) Elapsed 111m 31s (remain 471m 37s) Loss: 0.1246(0.3155) Grad: 1.4379  
Epoch: [1][7860/40998] Data 0.553 (0.553) Elapsed 111m 48s (remain 471m 20s) Loss: 0.2261(0.3154) Grad: 1.9103  
Epoch: [1][7880/40998] Data 0.553 (0.553) Elapsed 112m 5s (remain 471m 3s) Loss: 0.2142(0.3152) Grad: 4.0401  
Epoch: [1][7900/40998] Data 0.552 (0.553) Elapsed 112m 23s (remain 470m 46s) Loss: 0.2710(0.3150) Grad: 2.2426  
Epoch: [1][7920/40998] Data 0.553 (0.553) Elapsed 112m 40s (remain 470m 29s) Loss: 0.4663(0.3147) Grad: 8.0442  
Epoch: [1][7940/40998] Data 0.553 (0.553) Elapsed 112m 57s (remain 470m 12s) Loss: 0.2715(0.3146) Grad: 2.1494  
Epoch: [1][7960/40998] Data 0.553 (0.553) Elapsed 113m 14s (remain 469m 54s) Loss: 0.1994(0.3144) Grad: 3.0286  
Epoch: [1][7980/40998] Data 0.553 (0.553) Elapsed 113m 31s (remain 469m 37s) Loss: 0.3840(0.3142) Grad: 5.5330  
Epoch: [1][8000/40998] Data 0.552 (0.553) Elapsed 113m 48s (remain 469m 20s) Loss: 0.5636(0.3139) Grad: 6.6573  
Epoch: [1][8020/40998] Data 0.553 (0.553) Elapsed 114m 5s (remain 469m 3s) Loss: 0.2956(0.3138) Grad: 2.9597  
Epoch: [1][8040/40998] Data 0.553 (0.553) Elapsed 114m 22s (remain 468m 46s) Loss: 0.3440(0.3136) Grad: 4.0119  
Epoch: [1][8060/40998] Data 0.553 (0.553) Elapsed 114m 39s (remain 468m 29s) Loss: 0.2996(0.3133) Grad: 3.5992  
Epoch: [1][8080/40998] Data 0.553 (0.553) Elapsed 114m 56s (remain 468m 14s) Loss: 0.3323(0.3130) Grad: 2.7511  
Epoch: [1][8100/40998] Data 0.552 (0.553) Elapsed 115m 13s (remain 467m 55s) Loss: 0.1447(0.3129) Grad: 1.7840  
Epoch: [1][8120/40998] Data 0.553 (0.553) Elapsed 115m 30s (remain 467m 38s) Loss: 0.0963(0.3127) Grad: 2.1877  
Epoch: [1][8140/40998] Data 0.553 (0.553) Elapsed 115m 47s (remain 467m 21s) Loss: 0.2763(0.3124) Grad: 2.7835  
Epoch: [1][8160/40998] Data 0.553 (0.553) Elapsed 116m 4s (remain 467m 4s) Loss: 0.1076(0.3123) Grad: 1.5792  
Epoch: [1][8180/40998] Data 0.553 (0.553) Elapsed 116m 21s (remain 466m 47s) Loss: 0.0530(0.3121) Grad: 0.7544  
Epoch: [1][8200/40998] Data 0.553 (0.553) Elapsed 116m 38s (remain 466m 29s) Loss: 0.3634(0.3120) Grad: 4.5098  
Epoch: [1][8220/40998] Data 0.553 (0.553) Elapsed 116m 56s (remain 466m 12s) Loss: 0.0977(0.3118) Grad: 1.1528  
Epoch: [1][8240/40998] Data 0.553 (0.553) Elapsed 117m 13s (remain 465m 55s) Loss: 0.2761(0.3115) Grad: 3.2182  
Epoch: [1][8260/40998] Data 0.551 (0.553) Elapsed 117m 30s (remain 465m 38s) Loss: 0.1904(0.3112) Grad: 3.0004  
Epoch: [1][8280/40998] Data 0.552 (0.553) Elapsed 117m 47s (remain 465m 21s) Loss: 0.3905(0.3112) Grad: 4.3826  
Epoch: [1][8300/40998] Data 0.553 (0.553) Elapsed 118m 4s (remain 465m 4s) Loss: 0.2046(0.3110) Grad: 3.5094  
Epoch: [1][8320/40998] Data 0.552 (0.553) Elapsed 118m 21s (remain 464m 47s) Loss: 0.1888(0.3107) Grad: 2.5930  
Epoch: [1][8340/40998] Data 0.553 (0.553) Elapsed 118m 38s (remain 464m 30s) Loss: 0.1149(0.3103) Grad: 2.0985  
Epoch: [1][8360/40998] Data 0.553 (0.553) Elapsed 118m 55s (remain 464m 13s) Loss: 0.0526(0.3102) Grad: 0.8573  
Epoch: [1][8380/40998] Data 0.553 (0.553) Elapsed 119m 12s (remain 463m 56s) Loss: 0.1776(0.3100) Grad: 1.3333  
Epoch: [1][8400/40998] Data 0.553 (0.553) Elapsed 119m 29s (remain 463m 39s) Loss: 0.2564(0.3098) Grad: 3.7807  
Epoch: [1][8420/40998] Data 0.552 (0.553) Elapsed 119m 46s (remain 463m 21s) Loss: 0.0605(0.3095) Grad: 1.9135  
Epoch: [1][8440/40998] Data 0.553 (0.553) Elapsed 120m 3s (remain 463m 4s) Loss: 0.0925(0.3092) Grad: 1.5798  
Epoch: [1][8460/40998] Data 0.552 (0.553) Elapsed 120m 20s (remain 462m 47s) Loss: 0.2237(0.3090) Grad: 3.4717  
Epoch: [1][8480/40998] Data 0.552 (0.553) Elapsed 120m 37s (remain 462m 30s) Loss: 0.1953(0.3086) Grad: 2.2369  
Epoch: [1][8500/40998] Data 0.553 (0.553) Elapsed 120m 54s (remain 462m 13s) Loss: 0.2648(0.3084) Grad: 2.4650  
Epoch: [1][8520/40998] Data 0.553 (0.553) Elapsed 121m 11s (remain 461m 56s) Loss: 0.1297(0.3082) Grad: 2.3939  
Epoch: [1][8540/40998] Data 0.553 (0.553) Elapsed 121m 29s (remain 461m 39s) Loss: 0.2249(0.3080) Grad: 2.3639  
Epoch: [1][8560/40998] Data 0.553 (0.553) Elapsed 121m 46s (remain 461m 22s) Loss: 0.1523(0.3079) Grad: 2.2518  
Epoch: [1][8580/40998] Data 0.552 (0.553) Elapsed 122m 3s (remain 461m 5s) Loss: 0.1525(0.3076) Grad: 1.8975  
Epoch: [1][8600/40998] Data 0.553 (0.553) Elapsed 122m 20s (remain 460m 48s) Loss: 0.5005(0.3076) Grad: 5.8201  
Epoch: [1][8620/40998] Data 0.552 (0.553) Elapsed 122m 37s (remain 460m 31s) Loss: 0.2401(0.3074) Grad: 2.6190  
Epoch: [1][8640/40998] Data 0.552 (0.553) Elapsed 122m 54s (remain 460m 13s) Loss: 0.0555(0.3072) Grad: 1.1835  
Epoch: [1][8660/40998] Data 0.553 (0.553) Elapsed 123m 11s (remain 459m 56s) Loss: 0.1641(0.3070) Grad: 2.4011  
Epoch: [1][8680/40998] Data 0.553 (0.553) Elapsed 123m 28s (remain 459m 39s) Loss: 0.2482(0.3067) Grad: 3.2998  
Epoch: [1][8700/40998] Data 0.553 (0.553) Elapsed 123m 45s (remain 459m 22s) Loss: 0.1856(0.3065) Grad: 2.1897  
Epoch: [1][8720/40998] Data 0.553 (0.553) Elapsed 124m 2s (remain 459m 5s) Loss: 0.1369(0.3063) Grad: 2.0961  
Epoch: [1][8740/40998] Data 0.553 (0.553) Elapsed 124m 19s (remain 458m 48s) Loss: 0.3175(0.3061) Grad: 3.3475  
Epoch: [1][8760/40998] Data 0.553 (0.553) Elapsed 124m 36s (remain 458m 31s) Loss: 0.1796(0.3061) Grad: 1.3738  
Epoch: [1][8780/40998] Data 0.552 (0.553) Elapsed 124m 53s (remain 458m 14s) Loss: 0.3632(0.3060) Grad: 4.6801  
Epoch: [1][8800/40998] Data 0.552 (0.553) Elapsed 125m 10s (remain 457m 57s) Loss: 0.1331(0.3057) Grad: 2.1122  
Epoch: [1][8820/40998] Data 0.553 (0.553) Elapsed 125m 27s (remain 457m 40s) Loss: 0.3033(0.3055) Grad: 4.5069  
Epoch: [1][8840/40998] Data 0.553 (0.553) Elapsed 125m 44s (remain 457m 23s) Loss: 0.1644(0.3052) Grad: 2.6688  
Epoch: [1][8860/40998] Data 0.553 (0.553) Elapsed 126m 2s (remain 457m 5s) Loss: 0.2470(0.3050) Grad: 3.8769  
Epoch: [1][8880/40998] Data 0.553 (0.553) Elapsed 126m 19s (remain 456m 48s) Loss: 0.1042(0.3049) Grad: 1.5973  
Epoch: [1][8900/40998] Data 0.553 (0.553) Elapsed 126m 36s (remain 456m 31s) Loss: 0.0957(0.3048) Grad: 1.4324  
Epoch: [1][8920/40998] Data 0.553 (0.553) Elapsed 126m 53s (remain 456m 14s) Loss: 0.0964(0.3046) Grad: 1.3578  
Epoch: [1][8940/40998] Data 0.553 (0.553) Elapsed 127m 10s (remain 455m 57s) Loss: 0.0837(0.3043) Grad: 1.1818  
Epoch: [1][8960/40998] Data 0.553 (0.553) Elapsed 127m 27s (remain 455m 40s) Loss: 0.3886(0.3042) Grad: 3.1877  
Epoch: [1][8980/40998] Data 0.553 (0.553) Elapsed 127m 44s (remain 455m 23s) Loss: 0.1375(0.3040) Grad: 1.4285  
Epoch: [1][9000/40998] Data 0.552 (0.553) Elapsed 128m 1s (remain 455m 6s) Loss: 0.1663(0.3038) Grad: 1.9197  
Epoch: [1][9020/40998] Data 0.553 (0.553) Elapsed 128m 18s (remain 454m 49s) Loss: 0.2058(0.3036) Grad: 2.2798  
Epoch: [1][9040/40998] Data 0.553 (0.553) Elapsed 128m 35s (remain 454m 32s) Loss: 0.0668(0.3033) Grad: 0.9697  
Epoch: [1][9060/40998] Data 0.553 (0.553) Elapsed 128m 52s (remain 454m 15s) Loss: 0.3560(0.3031) Grad: 3.1274  
Epoch: [1][9080/40998] Data 0.553 (0.553) Elapsed 129m 9s (remain 453m 57s) Loss: 0.3393(0.3029) Grad: 3.6336  
Epoch: [1][9100/40998] Data 0.553 (0.553) Elapsed 129m 26s (remain 453m 40s) Loss: 0.3012(0.3028) Grad: 3.0730  
Epoch: [1][9120/40998] Data 0.553 (0.553) Elapsed 129m 43s (remain 453m 23s) Loss: 0.2861(0.3026) Grad: 2.6326  
Epoch: [1][9140/40998] Data 0.553 (0.553) Elapsed 130m 0s (remain 453m 6s) Loss: 0.0954(0.3024) Grad: 1.6698  
Epoch: [1][9160/40998] Data 0.552 (0.553) Elapsed 130m 17s (remain 452m 49s) Loss: 0.1190(0.3023) Grad: 2.1619  
Epoch: [1][9180/40998] Data 0.553 (0.553) Elapsed 130m 35s (remain 452m 32s) Loss: 0.1444(0.3020) Grad: 3.1157  
Epoch: [1][9200/40998] Data 0.553 (0.553) Elapsed 130m 52s (remain 452m 15s) Loss: 0.1557(0.3019) Grad: 2.7366  
Epoch: [1][9220/40998] Data 0.552 (0.553) Elapsed 131m 9s (remain 451m 58s) Loss: 0.1108(0.3017) Grad: 1.9261  
Epoch: [1][9240/40998] Data 0.552 (0.553) Elapsed 131m 26s (remain 451m 41s) Loss: 0.1360(0.3014) Grad: 2.8110  
Epoch: [1][9260/40998] Data 0.553 (0.553) Elapsed 131m 43s (remain 451m 24s) Loss: 0.1250(0.3012) Grad: 2.0337  
Epoch: [1][9280/40998] Data 0.553 (0.553) Elapsed 132m 0s (remain 451m 7s) Loss: 0.1082(0.3010) Grad: 2.1692  
Epoch: [1][9300/40998] Data 0.553 (0.553) Elapsed 132m 17s (remain 450m 50s) Loss: 0.0820(0.3007) Grad: 1.3971  
Epoch: [1][9320/40998] Data 0.553 (0.553) Elapsed 132m 34s (remain 450m 32s) Loss: 0.1662(0.3006) Grad: 1.3420  
Epoch: [1][9340/40998] Data 0.552 (0.553) Elapsed 132m 51s (remain 450m 15s) Loss: 0.0946(0.3004) Grad: 1.7330  
Epoch: [1][9360/40998] Data 0.553 (0.553) Elapsed 133m 8s (remain 449m 58s) Loss: 0.4719(0.3004) Grad: 6.2861  
Epoch: [1][9380/40998] Data 0.552 (0.553) Elapsed 133m 25s (remain 449m 41s) Loss: 0.1829(0.3002) Grad: 2.3480  
Epoch: [1][9400/40998] Data 0.552 (0.553) Elapsed 133m 42s (remain 449m 24s) Loss: 0.2828(0.3001) Grad: 2.8431  
Epoch: [1][9420/40998] Data 0.552 (0.553) Elapsed 133m 59s (remain 449m 7s) Loss: 0.2431(0.3000) Grad: 3.3322  
Epoch: [1][9440/40998] Data 0.553 (0.553) Elapsed 134m 16s (remain 448m 50s) Loss: 0.0499(0.2997) Grad: 0.7240  
Epoch: [1][9460/40998] Data 0.553 (0.553) Elapsed 134m 33s (remain 448m 33s) Loss: 0.1715(0.2995) Grad: 2.2409  
Epoch: [1][9480/40998] Data 0.552 (0.553) Elapsed 134m 50s (remain 448m 16s) Loss: 0.3134(0.2995) Grad: 4.8003  
Epoch: [1][9500/40998] Data 0.553 (0.553) Elapsed 135m 8s (remain 447m 59s) Loss: 0.4222(0.2993) Grad: 5.7802  
Epoch: [1][9520/40998] Data 0.553 (0.553) Elapsed 135m 25s (remain 447m 41s) Loss: 0.1154(0.2990) Grad: 1.1436  
Epoch: [1][9540/40998] Data 0.553 (0.553) Elapsed 135m 42s (remain 447m 24s) Loss: 0.3694(0.2989) Grad: 3.3291  
Epoch: [1][9560/40998] Data 0.553 (0.553) Elapsed 135m 59s (remain 447m 7s) Loss: 0.2890(0.2987) Grad: 3.4256  
Epoch: [1][9580/40998] Data 0.553 (0.553) Elapsed 136m 16s (remain 446m 50s) Loss: 0.2402(0.2986) Grad: 3.1124  
Epoch: [1][9600/40998] Data 0.552 (0.553) Elapsed 136m 33s (remain 446m 33s) Loss: 0.1067(0.2984) Grad: 1.8211  
Epoch: [1][9620/40998] Data 0.553 (0.553) Elapsed 136m 50s (remain 446m 16s) Loss: 0.2682(0.2982) Grad: 3.0080  
Epoch: [1][9640/40998] Data 0.552 (0.553) Elapsed 137m 7s (remain 445m 59s) Loss: 0.3773(0.2979) Grad: 4.7305  
Epoch: [1][9660/40998] Data 0.553 (0.553) Elapsed 137m 24s (remain 445m 42s) Loss: 0.0954(0.2977) Grad: 2.7707  
Epoch: [1][9680/40998] Data 0.553 (0.553) Elapsed 137m 41s (remain 445m 25s) Loss: 0.2019(0.2975) Grad: 3.0861  
Epoch: [1][9700/40998] Data 0.553 (0.553) Elapsed 137m 58s (remain 445m 8s) Loss: 0.0632(0.2974) Grad: 1.4904  
Epoch: [1][9720/40998] Data 0.552 (0.553) Elapsed 138m 15s (remain 444m 51s) Loss: 0.1240(0.2972) Grad: 1.7863  
Epoch: [1][9740/40998] Data 0.553 (0.553) Elapsed 138m 32s (remain 444m 33s) Loss: 0.1565(0.2970) Grad: 2.2276  
Epoch: [1][9760/40998] Data 0.552 (0.553) Elapsed 138m 49s (remain 444m 16s) Loss: 0.1022(0.2968) Grad: 1.7565  
Epoch: [1][9780/40998] Data 0.553 (0.553) Elapsed 139m 6s (remain 443m 59s) Loss: 0.2700(0.2966) Grad: 3.1657  
Epoch: [1][9800/40998] Data 0.553 (0.553) Elapsed 139m 23s (remain 443m 42s) Loss: 0.1452(0.2963) Grad: 2.3832  
Epoch: [1][9820/40998] Data 0.553 (0.553) Elapsed 139m 40s (remain 443m 25s) Loss: 0.2919(0.2961) Grad: 3.4792  
Epoch: [1][9840/40998] Data 0.552 (0.553) Elapsed 139m 58s (remain 443m 8s) Loss: 0.6499(0.2961) Grad: 6.6629  
Epoch: [1][9860/40998] Data 0.552 (0.553) Elapsed 140m 15s (remain 442m 51s) Loss: 0.4106(0.2960) Grad: 6.1597  
Epoch: [1][9880/40998] Data 0.553 (0.553) Elapsed 140m 32s (remain 442m 34s) Loss: 0.1633(0.2958) Grad: 1.2065  
Epoch: [1][9900/40998] Data 0.553 (0.553) Elapsed 140m 49s (remain 442m 17s) Loss: 0.1014(0.2957) Grad: 1.2803  
Epoch: [1][9920/40998] Data 0.553 (0.553) Elapsed 141m 6s (remain 442m 0s) Loss: 0.7928(0.2955) Grad: 10.4774  
Epoch: [1][9940/40998] Data 0.552 (0.553) Elapsed 141m 23s (remain 441m 43s) Loss: 0.3218(0.2954) Grad: 3.6060  
Epoch: [1][9960/40998] Data 0.553 (0.553) Elapsed 141m 40s (remain 441m 26s) Loss: 0.5285(0.2953) Grad: 6.1914  
Epoch: [1][9980/40998] Data 0.553 (0.553) Elapsed 141m 57s (remain 441m 8s) Loss: 0.0488(0.2951) Grad: 0.7472  
Epoch: [1][10000/40998] Data 0.552 (0.553) Elapsed 142m 14s (remain 440m 51s) Loss: 0.1101(0.2949) Grad: 1.6955  
Epoch: [1][10020/40998] Data 0.553 (0.553) Elapsed 142m 31s (remain 440m 34s) Loss: 0.1418(0.2947) Grad: 2.3735  
Epoch: [1][10040/40998] Data 0.552 (0.553) Elapsed 142m 48s (remain 440m 17s) Loss: 0.1886(0.2945) Grad: 2.6143  
Epoch: [1][10060/40998] Data 0.553 (0.553) Elapsed 143m 5s (remain 440m 0s) Loss: 0.1682(0.2943) Grad: 3.1873  
Epoch: [1][10080/40998] Data 0.553 (0.553) Elapsed 143m 22s (remain 439m 43s) Loss: 0.1563(0.2941) Grad: 1.9813  
Epoch: [1][10100/40998] Data 0.552 (0.553) Elapsed 143m 39s (remain 439m 26s) Loss: 0.1523(0.2940) Grad: 3.3176  
Epoch: [1][10120/40998] Data 0.553 (0.553) Elapsed 143m 56s (remain 439m 9s) Loss: 0.1653(0.2939) Grad: 2.4759  
Epoch: [1][10140/40998] Data 0.553 (0.553) Elapsed 144m 13s (remain 438m 52s) Loss: 0.2374(0.2937) Grad: 2.8077  
Epoch: [1][10160/40998] Data 0.553 (0.553) Elapsed 144m 31s (remain 438m 35s) Loss: 0.1224(0.2936) Grad: 1.2285  
Epoch: [1][10180/40998] Data 0.552 (0.553) Elapsed 144m 48s (remain 438m 18s) Loss: 0.4042(0.2934) Grad: 4.0073  
Epoch: [1][10200/40998] Data 0.554 (0.553) Elapsed 145m 5s (remain 438m 1s) Loss: 0.4165(0.2933) Grad: 5.7505  
Epoch: [1][10220/40998] Data 0.553 (0.553) Elapsed 145m 22s (remain 437m 44s) Loss: 0.1038(0.2931) Grad: 1.4230  
Epoch: [1][10240/40998] Data 0.553 (0.553) Elapsed 145m 39s (remain 437m 26s) Loss: 0.0256(0.2929) Grad: 0.3938  
Epoch: [1][10260/40998] Data 0.552 (0.553) Elapsed 145m 56s (remain 437m 9s) Loss: 0.2263(0.2928) Grad: 3.2504  
Epoch: [1][10280/40998] Data 0.553 (0.553) Elapsed 146m 13s (remain 436m 52s) Loss: 0.5224(0.2927) Grad: 4.2521  
Epoch: [1][10300/40998] Data 0.553 (0.553) Elapsed 146m 30s (remain 436m 35s) Loss: 0.2019(0.2925) Grad: 1.6706  
Epoch: [1][10320/40998] Data 0.553 (0.553) Elapsed 146m 47s (remain 436m 18s) Loss: 0.3010(0.2924) Grad: 3.1065  
Epoch: [1][10340/40998] Data 0.553 (0.553) Elapsed 147m 4s (remain 436m 1s) Loss: 0.0759(0.2922) Grad: 0.8501  
Epoch: [1][10360/40998] Data 0.553 (0.553) Elapsed 147m 21s (remain 435m 44s) Loss: 0.4984(0.2920) Grad: 5.3230  
Epoch: [1][10380/40998] Data 0.553 (0.553) Elapsed 147m 38s (remain 435m 27s) Loss: 0.1216(0.2918) Grad: 2.1360  
Epoch: [1][10400/40998] Data 0.553 (0.553) Elapsed 147m 55s (remain 435m 10s) Loss: 0.1922(0.2917) Grad: 2.5480  
Epoch: [1][10420/40998] Data 0.553 (0.553) Elapsed 148m 12s (remain 434m 53s) Loss: 0.1770(0.2915) Grad: 2.7236  
Epoch: [1][10440/40998] Data 0.553 (0.553) Elapsed 148m 29s (remain 434m 36s) Loss: 0.1676(0.2913) Grad: 2.0044  
Epoch: [1][10460/40998] Data 0.552 (0.553) Elapsed 148m 46s (remain 434m 19s) Loss: 0.6212(0.2912) Grad: 6.1231  
Epoch: [1][10480/40998] Data 0.553 (0.553) Elapsed 149m 4s (remain 434m 1s) Loss: 0.4213(0.2909) Grad: 3.6924  
Epoch: [1][10500/40998] Data 0.552 (0.553) Elapsed 149m 21s (remain 433m 44s) Loss: 0.2666(0.2907) Grad: 4.4078  
Epoch: [1][10520/40998] Data 0.552 (0.553) Elapsed 149m 38s (remain 433m 27s) Loss: 0.0554(0.2906) Grad: 0.9174  
Epoch: [1][10540/40998] Data 0.553 (0.553) Elapsed 149m 55s (remain 433m 10s) Loss: 0.2599(0.2904) Grad: 3.5663  
Epoch: [1][10560/40998] Data 0.553 (0.553) Elapsed 150m 12s (remain 432m 53s) Loss: 0.0923(0.2903) Grad: 1.8463  
Epoch: [1][10580/40998] Data 0.553 (0.553) Elapsed 150m 29s (remain 432m 36s) Loss: 0.1121(0.2901) Grad: 1.5749  
Epoch: [1][10600/40998] Data 0.553 (0.553) Elapsed 150m 46s (remain 432m 19s) Loss: 0.0971(0.2899) Grad: 0.8963  
Epoch: [1][10620/40998] Data 0.553 (0.553) Elapsed 151m 3s (remain 432m 2s) Loss: 0.3209(0.2898) Grad: 4.1762  
Epoch: [1][10640/40998] Data 0.552 (0.553) Elapsed 151m 20s (remain 431m 45s) Loss: 0.3506(0.2896) Grad: 3.3387  
Epoch: [1][10660/40998] Data 0.554 (0.553) Elapsed 151m 37s (remain 431m 28s) Loss: 0.3863(0.2895) Grad: 3.9963  
Epoch: [1][10680/40998] Data 0.553 (0.553) Elapsed 151m 54s (remain 431m 11s) Loss: 0.2034(0.2894) Grad: 2.1361  
Epoch: [1][10700/40998] Data 0.553 (0.553) Elapsed 152m 11s (remain 430m 54s) Loss: 0.1616(0.2892) Grad: 1.7826  
Epoch: [1][10720/40998] Data 0.552 (0.553) Elapsed 152m 28s (remain 430m 36s) Loss: 0.4334(0.2891) Grad: 3.1840  
Epoch: [1][10740/40998] Data 0.553 (0.553) Elapsed 152m 45s (remain 430m 19s) Loss: 0.0795(0.2889) Grad: 1.6903  
Epoch: [1][10760/40998] Data 0.553 (0.553) Elapsed 153m 2s (remain 430m 2s) Loss: 0.4079(0.2888) Grad: 3.5098  
Epoch: [1][10780/40998] Data 0.553 (0.553) Elapsed 153m 19s (remain 429m 45s) Loss: 0.1234(0.2885) Grad: 1.7600  
Epoch: [1][10800/40998] Data 0.552 (0.553) Elapsed 153m 37s (remain 429m 28s) Loss: 0.1356(0.2883) Grad: 1.7237  
Epoch: [1][10820/40998] Data 0.553 (0.553) Elapsed 153m 54s (remain 429m 11s) Loss: 0.3968(0.2882) Grad: 5.0036  
Epoch: [1][10840/40998] Data 0.552 (0.553) Elapsed 154m 11s (remain 428m 54s) Loss: 0.1598(0.2881) Grad: 2.9145  
Epoch: [1][10860/40998] Data 0.552 (0.553) Elapsed 154m 28s (remain 428m 37s) Loss: 0.6568(0.2879) Grad: 4.0339  
Epoch: [1][10880/40998] Data 0.553 (0.553) Elapsed 154m 45s (remain 428m 20s) Loss: 0.1365(0.2879) Grad: 2.0833  
Epoch: [1][10900/40998] Data 0.553 (0.553) Elapsed 155m 2s (remain 428m 3s) Loss: 0.0967(0.2877) Grad: 1.5512  
Epoch: [1][10920/40998] Data 0.553 (0.553) Elapsed 155m 19s (remain 427m 46s) Loss: 0.2174(0.2875) Grad: 2.2061  
Epoch: [1][10940/40998] Data 0.553 (0.553) Elapsed 155m 36s (remain 427m 29s) Loss: 0.1685(0.2874) Grad: 2.4892  
Epoch: [1][10960/40998] Data 0.552 (0.553) Elapsed 155m 53s (remain 427m 11s) Loss: 0.1538(0.2872) Grad: 2.3856  
Epoch: [1][10980/40998] Data 0.553 (0.553) Elapsed 156m 10s (remain 426m 54s) Loss: 0.0939(0.2871) Grad: 1.4951  
Epoch: [1][11000/40998] Data 0.553 (0.553) Elapsed 156m 27s (remain 426m 37s) Loss: 0.2898(0.2870) Grad: 3.3033  
Epoch: [1][11020/40998] Data 0.553 (0.553) Elapsed 156m 44s (remain 426m 20s) Loss: 0.3718(0.2868) Grad: 2.7753  
Epoch: [1][11040/40998] Data 0.553 (0.553) Elapsed 157m 1s (remain 426m 3s) Loss: 0.1946(0.2866) Grad: 2.3912  
Epoch: [1][11060/40998] Data 0.552 (0.553) Elapsed 157m 18s (remain 425m 46s) Loss: 0.3590(0.2865) Grad: 4.4216  
Epoch: [1][11080/40998] Data 0.553 (0.553) Elapsed 157m 35s (remain 425m 29s) Loss: 0.2536(0.2864) Grad: 2.8354  
Epoch: [1][11100/40998] Data 0.553 (0.553) Elapsed 157m 52s (remain 425m 12s) Loss: 0.3320(0.2862) Grad: 4.4136  
Epoch: [1][11120/40998] Data 0.553 (0.553) Elapsed 158m 9s (remain 424m 55s) Loss: 0.1996(0.2860) Grad: 1.6666  
Epoch: [1][11140/40998] Data 0.553 (0.553) Elapsed 158m 27s (remain 424m 38s) Loss: 0.2161(0.2858) Grad: 2.5799  
Epoch: [1][11160/40998] Data 0.553 (0.553) Elapsed 158m 44s (remain 424m 21s) Loss: 0.3692(0.2857) Grad: 3.3311  
Epoch: [1][11180/40998] Data 0.553 (0.553) Elapsed 159m 1s (remain 424m 3s) Loss: 0.1598(0.2856) Grad: 1.2819  
Epoch: [1][11200/40998] Data 0.553 (0.553) Elapsed 159m 18s (remain 423m 46s) Loss: 0.2326(0.2854) Grad: 1.8353  
Epoch: [1][11220/40998] Data 0.552 (0.553) Elapsed 159m 35s (remain 423m 29s) Loss: 0.1767(0.2853) Grad: 3.0297  
Epoch: [1][11240/40998] Data 0.553 (0.553) Elapsed 159m 52s (remain 423m 12s) Loss: 0.1292(0.2851) Grad: 1.7293  
Epoch: [1][11260/40998] Data 0.553 (0.553) Elapsed 160m 9s (remain 422m 55s) Loss: 0.0568(0.2849) Grad: 1.2657  
Epoch: [1][11280/40998] Data 0.553 (0.553) Elapsed 160m 26s (remain 422m 38s) Loss: 0.1996(0.2847) Grad: 3.1683  
Epoch: [1][11300/40998] Data 0.553 (0.553) Elapsed 160m 43s (remain 422m 21s) Loss: 0.3915(0.2846) Grad: 4.3228  
Epoch: [1][11320/40998] Data 0.552 (0.553) Elapsed 161m 0s (remain 422m 4s) Loss: 0.3222(0.2845) Grad: 3.2798  
Epoch: [1][11340/40998] Data 0.553 (0.553) Elapsed 161m 17s (remain 421m 47s) Loss: 0.1508(0.2844) Grad: 1.8556  
Epoch: [1][11360/40998] Data 0.553 (0.553) Elapsed 161m 34s (remain 421m 30s) Loss: 0.0613(0.2842) Grad: 0.7567  
Epoch: [1][11380/40998] Data 0.553 (0.553) Elapsed 161m 51s (remain 421m 13s) Loss: 0.1617(0.2841) Grad: 2.5667  
Epoch: [1][11400/40998] Data 0.553 (0.553) Elapsed 162m 8s (remain 420m 56s) Loss: 0.4391(0.2841) Grad: 4.3833  
Epoch: [1][11420/40998] Data 0.553 (0.553) Elapsed 162m 25s (remain 420m 38s) Loss: 0.2194(0.2840) Grad: 2.5412  
Epoch: [1][11440/40998] Data 0.553 (0.553) Elapsed 162m 42s (remain 420m 21s) Loss: 0.3348(0.2839) Grad: 4.1984  
Epoch: [1][11460/40998] Data 0.552 (0.553) Elapsed 163m 0s (remain 420m 4s) Loss: 0.1438(0.2838) Grad: 1.3391  
Epoch: [1][11480/40998] Data 0.553 (0.553) Elapsed 163m 17s (remain 419m 47s) Loss: 0.0572(0.2836) Grad: 0.9536  
Epoch: [1][11500/40998] Data 0.553 (0.553) Elapsed 163m 34s (remain 419m 30s) Loss: 0.2534(0.2834) Grad: 3.4075  
Epoch: [1][11520/40998] Data 0.553 (0.553) Elapsed 163m 51s (remain 419m 13s) Loss: 0.2573(0.2833) Grad: 3.6560  
Epoch: [1][11540/40998] Data 0.553 (0.553) Elapsed 164m 8s (remain 418m 56s) Loss: 0.0579(0.2832) Grad: 1.1304  
Epoch: [1][11560/40998] Data 0.553 (0.553) Elapsed 164m 25s (remain 418m 39s) Loss: 0.1019(0.2831) Grad: 1.6954  
Epoch: [1][11580/40998] Data 0.553 (0.553) Elapsed 164m 42s (remain 418m 22s) Loss: 0.2979(0.2830) Grad: 5.3529  
Epoch: [1][11600/40998] Data 0.553 (0.553) Elapsed 164m 59s (remain 418m 5s) Loss: 0.2278(0.2828) Grad: 1.8658  
Epoch: [1][11620/40998] Data 0.553 (0.553) Elapsed 165m 16s (remain 417m 48s) Loss: 0.1379(0.2827) Grad: 2.0557  
Epoch: [1][11640/40998] Data 0.553 (0.553) Elapsed 165m 33s (remain 417m 31s) Loss: 0.2468(0.2826) Grad: 2.8676  
Epoch: [1][11660/40998] Data 0.553 (0.553) Elapsed 165m 50s (remain 417m 13s) Loss: 0.4265(0.2825) Grad: 5.6323  
Epoch: [1][11680/40998] Data 0.553 (0.553) Elapsed 166m 7s (remain 416m 56s) Loss: 0.1056(0.2823) Grad: 1.2849  
Epoch: [1][11700/40998] Data 0.552 (0.553) Elapsed 166m 24s (remain 416m 39s) Loss: 0.1276(0.2821) Grad: 1.8242  
Epoch: [1][11720/40998] Data 0.554 (0.553) Elapsed 166m 41s (remain 416m 22s) Loss: 0.2268(0.2820) Grad: 3.7482  
Epoch: [1][11740/40998] Data 0.553 (0.553) Elapsed 166m 58s (remain 416m 5s) Loss: 0.3198(0.2819) Grad: 3.9428  
Epoch: [1][11760/40998] Data 0.552 (0.553) Elapsed 167m 15s (remain 415m 48s) Loss: 0.4441(0.2817) Grad: 5.2166  
Epoch: [1][11780/40998] Data 0.553 (0.553) Elapsed 167m 32s (remain 415m 31s) Loss: 0.1980(0.2816) Grad: 2.5334  
Epoch: [1][11800/40998] Data 0.552 (0.553) Elapsed 167m 50s (remain 415m 14s) Loss: 0.1651(0.2815) Grad: 2.0166  
Epoch: [1][11820/40998] Data 0.553 (0.553) Elapsed 168m 7s (remain 414m 57s) Loss: 0.1535(0.2814) Grad: 2.6256  
Epoch: [1][11840/40998] Data 0.552 (0.553) Elapsed 168m 24s (remain 414m 40s) Loss: 0.3905(0.2813) Grad: 4.1221  
Epoch: [1][11860/40998] Data 0.553 (0.553) Elapsed 168m 41s (remain 414m 23s) Loss: 0.2685(0.2810) Grad: 3.0302  
Epoch: [1][11880/40998] Data 0.553 (0.553) Elapsed 168m 58s (remain 414m 6s) Loss: 0.2195(0.2809) Grad: 1.5101  
Epoch: [1][11900/40998] Data 0.553 (0.553) Elapsed 169m 15s (remain 413m 49s) Loss: 0.1416(0.2807) Grad: 1.8085  
Epoch: [1][11920/40998] Data 0.552 (0.553) Elapsed 169m 32s (remain 413m 31s) Loss: 0.1171(0.2805) Grad: 2.8623  
Epoch: [1][11940/40998] Data 0.553 (0.553) Elapsed 169m 49s (remain 413m 14s) Loss: 0.1078(0.2804) Grad: 2.1971  
Epoch: [1][11960/40998] Data 0.553 (0.553) Elapsed 170m 6s (remain 412m 57s) Loss: 0.2241(0.2802) Grad: 3.0881  
Epoch: [1][11980/40998] Data 0.552 (0.553) Elapsed 170m 23s (remain 412m 40s) Loss: 0.1337(0.2801) Grad: 3.1350  
Epoch: [1][12000/40998] Data 0.553 (0.553) Elapsed 170m 40s (remain 412m 23s) Loss: 0.4619(0.2800) Grad: 4.3302  
Epoch: [1][12020/40998] Data 0.553 (0.553) Elapsed 170m 57s (remain 412m 6s) Loss: 0.2204(0.2799) Grad: 2.9538  
Epoch: [1][12040/40998] Data 0.553 (0.553) Elapsed 171m 14s (remain 411m 49s) Loss: 0.0712(0.2797) Grad: 0.8668  
Epoch: [1][12060/40998] Data 0.553 (0.553) Elapsed 171m 31s (remain 411m 32s) Loss: 0.3250(0.2798) Grad: 4.2233  
Epoch: [1][12080/40998] Data 0.553 (0.553) Elapsed 171m 48s (remain 411m 15s) Loss: 0.3344(0.2797) Grad: 2.8720  
Epoch: [1][12100/40998] Data 0.552 (0.553) Elapsed 172m 5s (remain 410m 58s) Loss: 0.1912(0.2796) Grad: 2.7060  
Epoch: [1][12120/40998] Data 0.552 (0.553) Elapsed 172m 23s (remain 410m 41s) Loss: 0.1051(0.2794) Grad: 1.0307  
Epoch: [1][12140/40998] Data 0.553 (0.553) Elapsed 172m 40s (remain 410m 24s) Loss: 0.4672(0.2794) Grad: 3.4981  
Epoch: [1][12160/40998] Data 0.553 (0.553) Elapsed 172m 57s (remain 410m 6s) Loss: 0.2009(0.2793) Grad: 2.7937  
Epoch: [1][12180/40998] Data 0.552 (0.553) Elapsed 173m 14s (remain 409m 49s) Loss: 0.1070(0.2791) Grad: 1.8084  
Epoch: [1][12200/40998] Data 0.553 (0.553) Elapsed 173m 31s (remain 409m 32s) Loss: 0.0850(0.2790) Grad: 1.2566  
Epoch: [1][12220/40998] Data 0.553 (0.553) Elapsed 173m 48s (remain 409m 15s) Loss: 0.0818(0.2789) Grad: 1.2756  
Epoch: [1][12240/40998] Data 0.553 (0.553) Elapsed 174m 5s (remain 408m 58s) Loss: 0.1323(0.2787) Grad: 3.0956  
Epoch: [1][12260/40998] Data 0.553 (0.553) Elapsed 174m 22s (remain 408m 41s) Loss: 0.0608(0.2785) Grad: 1.1836  
Epoch: [1][12280/40998] Data 0.553 (0.553) Elapsed 174m 39s (remain 408m 24s) Loss: 0.5868(0.2785) Grad: 5.6230  
Epoch: [1][12300/40998] Data 0.553 (0.553) Elapsed 174m 56s (remain 408m 7s) Loss: 0.1592(0.2784) Grad: 2.2664  
Epoch: [1][12320/40998] Data 0.553 (0.553) Elapsed 175m 13s (remain 407m 50s) Loss: 0.1328(0.2782) Grad: 1.7508  
Epoch: [1][12340/40998] Data 0.553 (0.553) Elapsed 175m 30s (remain 407m 33s) Loss: 0.0962(0.2780) Grad: 1.9179  
Epoch: [1][12360/40998] Data 0.553 (0.553) Elapsed 175m 47s (remain 407m 16s) Loss: 0.0645(0.2779) Grad: 0.8233  
Epoch: [1][12380/40998] Data 0.553 (0.553) Elapsed 176m 4s (remain 406m 59s) Loss: 0.1937(0.2777) Grad: 2.9432  
Epoch: [1][12400/40998] Data 0.552 (0.553) Elapsed 176m 21s (remain 406m 42s) Loss: 0.2210(0.2776) Grad: 3.5599  
Epoch: [1][12420/40998] Data 0.553 (0.553) Elapsed 176m 38s (remain 406m 24s) Loss: 0.0612(0.2775) Grad: 1.1432  
Epoch: [1][12440/40998] Data 0.552 (0.553) Elapsed 176m 55s (remain 406m 7s) Loss: 0.2375(0.2773) Grad: 3.4121  
Epoch: [1][12460/40998] Data 0.553 (0.553) Elapsed 177m 13s (remain 405m 50s) Loss: 0.3726(0.2772) Grad: 2.5110  
Epoch: [1][12480/40998] Data 0.553 (0.553) Elapsed 177m 30s (remain 405m 33s) Loss: 0.2202(0.2770) Grad: 3.1691  
Epoch: [1][12500/40998] Data 0.553 (0.553) Elapsed 177m 47s (remain 405m 16s) Loss: 0.0823(0.2768) Grad: 1.2501  
Epoch: [1][12520/40998] Data 0.553 (0.553) Elapsed 178m 4s (remain 404m 59s) Loss: 0.0437(0.2766) Grad: 0.8163  
Epoch: [1][12540/40998] Data 0.552 (0.553) Elapsed 178m 21s (remain 404m 42s) Loss: 0.1465(0.2765) Grad: 1.9858  
Epoch: [1][12560/40998] Data 0.553 (0.553) Elapsed 178m 38s (remain 404m 25s) Loss: 0.3101(0.2763) Grad: 4.8959  
Epoch: [1][12580/40998] Data 0.553 (0.553) Elapsed 178m 55s (remain 404m 8s) Loss: 0.1404(0.2762) Grad: 1.5938  
Epoch: [1][12600/40998] Data 0.553 (0.553) Elapsed 179m 12s (remain 403m 51s) Loss: 0.2195(0.2761) Grad: 2.4182  
Epoch: [1][12620/40998] Data 0.552 (0.553) Elapsed 179m 29s (remain 403m 34s) Loss: 0.1222(0.2759) Grad: 2.2186  
Epoch: [1][12640/40998] Data 0.552 (0.553) Elapsed 179m 46s (remain 403m 17s) Loss: 0.1959(0.2757) Grad: 2.9698  
Epoch: [1][12660/40998] Data 0.553 (0.553) Elapsed 180m 3s (remain 403m 0s) Loss: 0.1455(0.2756) Grad: 3.3279  
Epoch: [1][12680/40998] Data 0.553 (0.553) Elapsed 180m 20s (remain 402m 42s) Loss: 0.4566(0.2755) Grad: 5.2471  
Epoch: [1][12700/40998] Data 0.553 (0.553) Elapsed 180m 37s (remain 402m 25s) Loss: 0.1578(0.2753) Grad: 2.5220  
Epoch: [1][12720/40998] Data 0.552 (0.553) Elapsed 180m 54s (remain 402m 8s) Loss: 0.1604(0.2753) Grad: 1.4084  
Epoch: [1][12740/40998] Data 0.553 (0.553) Elapsed 181m 11s (remain 401m 51s) Loss: 0.3172(0.2751) Grad: 3.1501  
Epoch: [1][12760/40998] Data 0.553 (0.553) Elapsed 181m 28s (remain 401m 34s) Loss: 0.1468(0.2749) Grad: 2.8432  
Epoch: [1][12780/40998] Data 0.553 (0.553) Elapsed 181m 46s (remain 401m 17s) Loss: 0.2035(0.2747) Grad: 1.9159  
Epoch: [1][12800/40998] Data 0.553 (0.553) Elapsed 182m 3s (remain 401m 0s) Loss: 0.1491(0.2747) Grad: 2.8245  
Epoch: [1][12820/40998] Data 0.553 (0.553) Elapsed 182m 20s (remain 400m 43s) Loss: 0.1649(0.2745) Grad: 1.7007  
Epoch: [1][12840/40998] Data 0.553 (0.553) Elapsed 182m 37s (remain 400m 26s) Loss: 0.5324(0.2745) Grad: 6.4621  
Epoch: [1][12860/40998] Data 0.553 (0.553) Elapsed 182m 54s (remain 400m 9s) Loss: 0.1526(0.2744) Grad: 2.3720  
Epoch: [1][12880/40998] Data 0.553 (0.553) Elapsed 183m 11s (remain 399m 52s) Loss: 0.3317(0.2743) Grad: 3.5787  
Epoch: [1][12900/40998] Data 0.552 (0.553) Elapsed 183m 28s (remain 399m 35s) Loss: 0.4026(0.2744) Grad: 3.1777  
Epoch: [1][12920/40998] Data 0.553 (0.553) Elapsed 183m 45s (remain 399m 18s) Loss: 0.2084(0.2742) Grad: 1.9413  
Epoch: [1][12940/40998] Data 0.553 (0.553) Elapsed 184m 2s (remain 399m 0s) Loss: 0.1895(0.2742) Grad: 1.7200  
Epoch: [1][12960/40998] Data 0.553 (0.553) Elapsed 184m 19s (remain 398m 43s) Loss: 0.0665(0.2741) Grad: 1.7799  
Epoch: [1][12980/40998] Data 0.554 (0.553) Elapsed 184m 36s (remain 398m 26s) Loss: 0.2867(0.2740) Grad: 2.7899  
Epoch: [1][13000/40998] Data 0.552 (0.553) Elapsed 184m 53s (remain 398m 9s) Loss: 0.3508(0.2738) Grad: 4.1978  
Epoch: [1][13020/40998] Data 0.553 (0.553) Elapsed 185m 10s (remain 397m 52s) Loss: 0.1074(0.2737) Grad: 1.8525  
Epoch: [1][13040/40998] Data 0.552 (0.553) Elapsed 185m 27s (remain 397m 35s) Loss: 0.1009(0.2736) Grad: 1.8458  
Epoch: [1][13060/40998] Data 0.553 (0.553) Elapsed 185m 44s (remain 397m 18s) Loss: 0.0428(0.2734) Grad: 0.7681  
Epoch: [1][13080/40998] Data 0.553 (0.553) Elapsed 186m 1s (remain 397m 1s) Loss: 0.5188(0.2733) Grad: 4.2225  
Epoch: [1][13100/40998] Data 0.552 (0.553) Elapsed 186m 19s (remain 396m 44s) Loss: 0.1424(0.2731) Grad: 2.8646  
Epoch: [1][13120/40998] Data 0.553 (0.553) Elapsed 186m 36s (remain 396m 27s) Loss: 0.1236(0.2730) Grad: 1.4198  
Epoch: [1][13140/40998] Data 0.552 (0.553) Elapsed 186m 53s (remain 396m 10s) Loss: 0.2005(0.2730) Grad: 2.7191  
Epoch: [1][13160/40998] Data 0.553 (0.553) Elapsed 187m 10s (remain 395m 53s) Loss: 0.1143(0.2728) Grad: 1.9377  
Epoch: [1][13180/40998] Data 0.552 (0.553) Elapsed 187m 27s (remain 395m 36s) Loss: 0.0813(0.2727) Grad: 1.9954  
Epoch: [1][13200/40998] Data 0.553 (0.553) Elapsed 187m 44s (remain 395m 18s) Loss: 0.2514(0.2725) Grad: 5.0005  
Epoch: [1][13220/40998] Data 0.553 (0.553) Elapsed 188m 1s (remain 395m 1s) Loss: 0.1540(0.2723) Grad: 3.1987  
Epoch: [1][13240/40998] Data 0.553 (0.553) Elapsed 188m 18s (remain 394m 44s) Loss: 0.1607(0.2723) Grad: 2.1414  
Epoch: [1][13260/40998] Data 0.553 (0.553) Elapsed 188m 35s (remain 394m 27s) Loss: 0.0780(0.2722) Grad: 1.3876  
Epoch: [1][13280/40998] Data 0.553 (0.553) Elapsed 188m 52s (remain 394m 10s) Loss: 0.0738(0.2720) Grad: 1.4101  
Epoch: [1][13300/40998] Data 0.553 (0.553) Elapsed 189m 9s (remain 393m 53s) Loss: 0.2876(0.2718) Grad: 2.8325  
Epoch: [1][13320/40998] Data 0.553 (0.553) Elapsed 189m 26s (remain 393m 36s) Loss: 0.1924(0.2717) Grad: 1.5424  
Epoch: [1][13340/40998] Data 0.553 (0.553) Elapsed 189m 43s (remain 393m 19s) Loss: 0.4726(0.2716) Grad: 3.5446  
Epoch: [1][13360/40998] Data 0.553 (0.553) Elapsed 190m 0s (remain 393m 2s) Loss: 0.0591(0.2714) Grad: 1.6184  
Epoch: [1][13380/40998] Data 0.553 (0.553) Elapsed 190m 17s (remain 392m 45s) Loss: 0.1623(0.2714) Grad: 2.6068  
Epoch: [1][13400/40998] Data 0.553 (0.553) Elapsed 190m 34s (remain 392m 28s) Loss: 0.3665(0.2712) Grad: 3.1257  
Epoch: [1][13420/40998] Data 0.553 (0.553) Elapsed 190m 52s (remain 392m 11s) Loss: 0.2133(0.2710) Grad: 2.2932  
Epoch: [1][13440/40998] Data 0.553 (0.553) Elapsed 191m 9s (remain 391m 54s) Loss: 0.5188(0.2709) Grad: 4.5501  
Epoch: [1][13460/40998] Data 0.553 (0.553) Elapsed 191m 26s (remain 391m 37s) Loss: 0.0752(0.2707) Grad: 1.2922  
Epoch: [1][13480/40998] Data 0.553 (0.553) Elapsed 191m 43s (remain 391m 19s) Loss: 0.1093(0.2706) Grad: 1.7965  
Epoch: [1][13500/40998] Data 0.553 (0.553) Elapsed 192m 0s (remain 391m 2s) Loss: 0.0759(0.2705) Grad: 1.4157  
Epoch: [1][13520/40998] Data 0.553 (0.553) Elapsed 192m 17s (remain 390m 45s) Loss: 0.1198(0.2704) Grad: 1.5453  
Epoch: [1][13540/40998] Data 0.552 (0.553) Elapsed 192m 34s (remain 390m 28s) Loss: 0.2351(0.2702) Grad: 2.8289  
Epoch: [1][13560/40998] Data 0.553 (0.553) Elapsed 192m 51s (remain 390m 11s) Loss: 0.0649(0.2701) Grad: 1.0979  
Epoch: [1][13580/40998] Data 0.553 (0.553) Elapsed 193m 8s (remain 389m 54s) Loss: 0.2389(0.2699) Grad: 2.1562  
Epoch: [1][13600/40998] Data 0.553 (0.553) Elapsed 193m 25s (remain 389m 37s) Loss: 0.0966(0.2698) Grad: 1.6538  
Epoch: [1][13620/40998] Data 0.553 (0.553) Elapsed 193m 42s (remain 389m 20s) Loss: 0.0821(0.2696) Grad: 1.9282  
Epoch: [1][13640/40998] Data 0.552 (0.553) Elapsed 193m 59s (remain 389m 3s) Loss: 0.2066(0.2695) Grad: 3.0901  
Epoch: [1][13660/40998] Data 0.553 (0.553) Elapsed 194m 16s (remain 388m 46s) Loss: 0.1418(0.2693) Grad: 2.2154  
Epoch: [1][13680/40998] Data 0.553 (0.553) Elapsed 194m 33s (remain 388m 29s) Loss: 0.3653(0.2692) Grad: 4.4659  
Epoch: [1][13700/40998] Data 0.553 (0.553) Elapsed 194m 50s (remain 388m 12s) Loss: 0.2434(0.2691) Grad: 3.1364  
Epoch: [1][13720/40998] Data 0.553 (0.553) Elapsed 195m 7s (remain 387m 55s) Loss: 0.1268(0.2690) Grad: 1.1831  
Epoch: [1][13740/40998] Data 0.552 (0.553) Elapsed 195m 24s (remain 387m 38s) Loss: 0.1555(0.2689) Grad: 1.7581  
Epoch: [1][13760/40998] Data 0.553 (0.553) Elapsed 195m 42s (remain 387m 20s) Loss: 0.0566(0.2688) Grad: 1.1725  
Epoch: [1][13780/40998] Data 0.553 (0.553) Elapsed 195m 59s (remain 387m 3s) Loss: 0.0174(0.2687) Grad: 0.2886  
Epoch: [1][13800/40998] Data 0.552 (0.553) Elapsed 196m 16s (remain 386m 46s) Loss: 0.2373(0.2686) Grad: 2.7507  
Epoch: [1][13820/40998] Data 0.553 (0.553) Elapsed 196m 33s (remain 386m 29s) Loss: 0.4093(0.2685) Grad: 2.8707  
Epoch: [1][13840/40998] Data 0.553 (0.553) Elapsed 196m 50s (remain 386m 12s) Loss: 0.1244(0.2684) Grad: 2.1606  
Epoch: [1][13860/40998] Data 0.552 (0.553) Elapsed 197m 7s (remain 385m 55s) Loss: 0.1576(0.2682) Grad: 1.5257  
Epoch: [1][13880/40998] Data 0.553 (0.553) Elapsed 197m 24s (remain 385m 38s) Loss: 0.3625(0.2681) Grad: 6.8079  
Epoch: [1][13900/40998] Data 0.553 (0.553) Elapsed 197m 41s (remain 385m 21s) Loss: 0.2467(0.2680) Grad: 2.1558  
Epoch: [1][13920/40998] Data 0.553 (0.553) Elapsed 197m 58s (remain 385m 4s) Loss: 0.1414(0.2679) Grad: 2.4116  
Epoch: [1][13940/40998] Data 0.552 (0.553) Elapsed 198m 15s (remain 384m 47s) Loss: 0.1597(0.2678) Grad: 2.4256  
Epoch: [1][13960/40998] Data 0.553 (0.553) Elapsed 198m 32s (remain 384m 30s) Loss: 0.1836(0.2677) Grad: 1.8756  
Epoch: [1][13980/40998] Data 0.553 (0.553) Elapsed 198m 49s (remain 384m 13s) Loss: 0.1919(0.2675) Grad: 2.7549  
Epoch: [1][14000/40998] Data 0.553 (0.553) Elapsed 199m 6s (remain 383m 56s) Loss: 0.0990(0.2675) Grad: 1.5303  
Epoch: [1][14020/40998] Data 0.552 (0.553) Elapsed 199m 23s (remain 383m 38s) Loss: 0.2932(0.2674) Grad: 2.5059  
Epoch: [1][14040/40998] Data 0.553 (0.553) Elapsed 199m 40s (remain 383m 21s) Loss: 0.6504(0.2673) Grad: 4.8618  
Epoch: [1][14060/40998] Data 0.553 (0.553) Elapsed 199m 57s (remain 383m 4s) Loss: 0.0658(0.2671) Grad: 1.0269  
Epoch: [1][14080/40998] Data 0.552 (0.553) Elapsed 200m 15s (remain 382m 47s) Loss: 0.1637(0.2669) Grad: 3.0674  
Epoch: [1][14100/40998] Data 0.553 (0.553) Elapsed 200m 32s (remain 382m 30s) Loss: 0.3702(0.2668) Grad: 3.6453  
Epoch: [1][14120/40998] Data 0.553 (0.553) Elapsed 200m 49s (remain 382m 13s) Loss: 0.0752(0.2667) Grad: 1.3577  
Epoch: [1][14140/40998] Data 0.553 (0.553) Elapsed 201m 6s (remain 381m 56s) Loss: 0.1755(0.2665) Grad: 2.7587  
Epoch: [1][14160/40998] Data 0.553 (0.553) Elapsed 201m 23s (remain 381m 39s) Loss: 0.4580(0.2665) Grad: 5.6603  
Epoch: [1][14180/40998] Data 0.553 (0.553) Elapsed 201m 40s (remain 381m 22s) Loss: 0.0533(0.2664) Grad: 0.6712  
Epoch: [1][14200/40998] Data 0.552 (0.553) Elapsed 201m 57s (remain 381m 5s) Loss: 0.2404(0.2662) Grad: 2.2382  
Epoch: [1][14220/40998] Data 0.553 (0.553) Elapsed 202m 14s (remain 380m 48s) Loss: 0.1961(0.2661) Grad: 2.9530  
Epoch: [1][14240/40998] Data 0.553 (0.553) Elapsed 202m 31s (remain 380m 31s) Loss: 0.2574(0.2660) Grad: 2.5785  
Epoch: [1][14260/40998] Data 0.553 (0.553) Elapsed 202m 48s (remain 380m 14s) Loss: 0.2457(0.2659) Grad: 3.2384  
Epoch: [1][14280/40998] Data 0.553 (0.553) Elapsed 203m 5s (remain 379m 57s) Loss: 0.1459(0.2658) Grad: 2.2444  
Epoch: [1][14300/40998] Data 0.553 (0.553) Elapsed 203m 22s (remain 379m 40s) Loss: 0.0841(0.2657) Grad: 1.2871  
Epoch: [1][14320/40998] Data 0.553 (0.553) Elapsed 203m 39s (remain 379m 22s) Loss: 0.1690(0.2656) Grad: 1.7188  
Epoch: [1][14340/40998] Data 0.553 (0.553) Elapsed 203m 56s (remain 379m 5s) Loss: 0.1390(0.2655) Grad: 1.7920  
Epoch: [1][14360/40998] Data 0.553 (0.553) Elapsed 204m 13s (remain 378m 48s) Loss: 0.0463(0.2653) Grad: 1.2446  
Epoch: [1][14380/40998] Data 0.553 (0.553) Elapsed 204m 30s (remain 378m 31s) Loss: 0.1054(0.2652) Grad: 1.5822  
Epoch: [1][14400/40998] Data 0.553 (0.553) Elapsed 204m 48s (remain 378m 14s) Loss: 0.2709(0.2651) Grad: 2.1543  
Epoch: [1][14420/40998] Data 0.553 (0.553) Elapsed 205m 5s (remain 377m 57s) Loss: 0.2199(0.2650) Grad: 2.4892  
Epoch: [1][14440/40998] Data 0.553 (0.553) Elapsed 205m 22s (remain 377m 40s) Loss: 0.0687(0.2649) Grad: 0.7389  
Epoch: [1][14460/40998] Data 0.553 (0.553) Elapsed 205m 39s (remain 377m 23s) Loss: 0.2767(0.2648) Grad: 4.8113  
Epoch: [1][14480/40998] Data 0.553 (0.553) Elapsed 205m 56s (remain 377m 6s) Loss: 0.4241(0.2647) Grad: 3.4242  
Epoch: [1][14500/40998] Data 0.553 (0.553) Elapsed 206m 13s (remain 376m 49s) Loss: 0.0687(0.2646) Grad: 1.1305  
Epoch: [1][14520/40998] Data 0.553 (0.553) Elapsed 206m 30s (remain 376m 32s) Loss: 0.0617(0.2645) Grad: 1.2359  
Epoch: [1][14540/40998] Data 0.553 (0.553) Elapsed 206m 47s (remain 376m 15s) Loss: 0.1282(0.2644) Grad: 2.7051  
Epoch: [1][14560/40998] Data 0.552 (0.553) Elapsed 207m 4s (remain 375m 58s) Loss: 0.1404(0.2643) Grad: 2.1054  
Epoch: [1][14580/40998] Data 0.553 (0.553) Elapsed 207m 21s (remain 375m 41s) Loss: 0.0817(0.2642) Grad: 1.1332  
Epoch: [1][14600/40998] Data 0.553 (0.553) Elapsed 207m 38s (remain 375m 23s) Loss: 0.2427(0.2641) Grad: 3.0237  
Epoch: [1][14620/40998] Data 0.552 (0.553) Elapsed 207m 55s (remain 375m 6s) Loss: 0.1971(0.2639) Grad: 4.5254  
Epoch: [1][14640/40998] Data 0.553 (0.553) Elapsed 208m 12s (remain 374m 49s) Loss: 0.4228(0.2639) Grad: 2.9514  
Epoch: [1][14660/40998] Data 0.552 (0.553) Elapsed 208m 29s (remain 374m 32s) Loss: 0.1687(0.2637) Grad: 3.0671  
Epoch: [1][14680/40998] Data 0.553 (0.553) Elapsed 208m 46s (remain 374m 15s) Loss: 0.1871(0.2636) Grad: 1.8404  
Epoch: [1][14700/40998] Data 0.551 (0.553) Elapsed 209m 4s (remain 373m 58s) Loss: 0.1380(0.2635) Grad: 1.9553  
Epoch: [1][14720/40998] Data 0.553 (0.553) Elapsed 209m 21s (remain 373m 41s) Loss: 0.3043(0.2634) Grad: 2.0631  
Epoch: [1][14740/40998] Data 0.552 (0.553) Elapsed 209m 38s (remain 373m 24s) Loss: 0.2329(0.2633) Grad: 1.7245  
Epoch: [1][14760/40998] Data 0.553 (0.553) Elapsed 209m 55s (remain 373m 7s) Loss: 0.0757(0.2632) Grad: 0.9867  
Epoch: [1][14780/40998] Data 0.553 (0.553) Elapsed 210m 12s (remain 372m 50s) Loss: 0.2747(0.2630) Grad: 4.1493  
Epoch: [1][14800/40998] Data 0.552 (0.553) Elapsed 210m 29s (remain 372m 33s) Loss: 0.1762(0.2629) Grad: 2.4743  
Epoch: [1][14820/40998] Data 0.553 (0.553) Elapsed 210m 46s (remain 372m 16s) Loss: 0.1532(0.2627) Grad: 2.4626  
Epoch: [1][14840/40998] Data 0.553 (0.553) Elapsed 211m 3s (remain 371m 59s) Loss: 0.0270(0.2626) Grad: 0.4831  
Epoch: [1][14860/40998] Data 0.553 (0.553) Elapsed 211m 20s (remain 371m 42s) Loss: 0.0933(0.2625) Grad: 1.0528  
Epoch: [1][14880/40998] Data 0.554 (0.553) Elapsed 211m 37s (remain 371m 25s) Loss: 0.1074(0.2624) Grad: 1.3818  
Epoch: [1][14900/40998] Data 0.552 (0.553) Elapsed 211m 54s (remain 371m 8s) Loss: 0.1552(0.2623) Grad: 2.4732  
Epoch: [1][14920/40998] Data 0.553 (0.553) Elapsed 212m 11s (remain 370m 50s) Loss: 0.0755(0.2622) Grad: 1.0346  
Epoch: [1][14940/40998] Data 0.553 (0.553) Elapsed 212m 28s (remain 370m 33s) Loss: 0.5354(0.2621) Grad: 9.7909  
Epoch: [1][14960/40998] Data 0.553 (0.553) Elapsed 212m 45s (remain 370m 16s) Loss: 0.2245(0.2620) Grad: 4.2239  
Epoch: [1][14980/40998] Data 0.553 (0.553) Elapsed 213m 2s (remain 369m 59s) Loss: 0.2694(0.2619) Grad: 3.8000  
Epoch: [1][15000/40998] Data 0.553 (0.553) Elapsed 213m 20s (remain 369m 42s) Loss: 0.2818(0.2618) Grad: 2.8980  
Epoch: [1][15020/40998] Data 0.553 (0.553) Elapsed 213m 37s (remain 369m 25s) Loss: 0.2178(0.2617) Grad: 2.7797  
Epoch: [1][15040/40998] Data 0.553 (0.553) Elapsed 213m 54s (remain 369m 8s) Loss: 0.2140(0.2616) Grad: 3.3398  
Epoch: [1][15060/40998] Data 0.553 (0.553) Elapsed 214m 11s (remain 368m 51s) Loss: 0.0599(0.2614) Grad: 1.0378  
Epoch: [1][15080/40998] Data 0.553 (0.553) Elapsed 214m 28s (remain 368m 34s) Loss: 0.3475(0.2612) Grad: 4.3085  
Epoch: [1][15100/40998] Data 0.553 (0.553) Elapsed 214m 45s (remain 368m 17s) Loss: 0.2737(0.2612) Grad: 2.5683  
Epoch: [1][15120/40998] Data 0.553 (0.553) Elapsed 215m 2s (remain 368m 0s) Loss: 0.2075(0.2611) Grad: 4.2108  
Epoch: [1][15260/40998] Data 0.552 (0.553) Elapsed 217m 1s (remain 366m 0s) Loss: 0.3030(0.2605) Grad: 2.5120  
Epoch: [1][15280/40998] Data 0.553 (0.553) Elapsed 217m 18s (remain 365m 43s) Loss: 0.4748(0.2603) Grad: 4.0414  
Epoch: [1][15300/40998] Data 0.553 (0.553) Elapsed 217m 35s (remain 365m 26s) Loss: 0.1179(0.2602) Grad: 1.6243  
Epoch: [1][15320/40998] Data 0.552 (0.553) Elapsed 217m 53s (remain 365m 9s) Loss: 0.2600(0.2600) Grad: 2.6786  
Epoch: [1][15340/40998] Data 0.553 (0.553) Elapsed 218m 10s (remain 364m 52s) Loss: 0.0788(0.2600) Grad: 1.1963  
Epoch: [1][15360/40998] Data 0.552 (0.553) Elapsed 218m 27s (remain 364m 35s) Loss: 0.1258(0.2599) Grad: 2.2987  
Epoch: [1][15380/40998] Data 0.553 (0.553) Elapsed 218m 44s (remain 364m 18s) Loss: 0.0321(0.2598) Grad: 0.5624  
Epoch: [1][15440/40998] Data 0.553 (0.553) Elapsed 219m 35s (remain 363m 27s) Loss: 0.0484(0.2593) Grad: 0.8680  
Epoch: [1][15460/40998] Data 0.552 (0.553) Elapsed 219m 52s (remain 363m 10s) Loss: 0.2398(0.2592) Grad: 1.9606  
Epoch: [1][15480/40998] Data 0.553 (0.553) Elapsed 220m 9s (remain 362m 52s) Loss: 0.2526(0.2591) Grad: 3.7836  
Epoch: [1][15500/40998] Data 0.553 (0.553) Elapsed 220m 26s (remain 362m 35s) Loss: 0.0701(0.2590) Grad: 0.9821  
Epoch: [1][15520/40998] Data 0.553 (0.553) Elapsed 220m 43s (remain 362m 18s) Loss: 0.3168(0.2589) Grad: 3.4109  
Epoch: [1][15540/40998] Data 0.553 (0.553) Elapsed 221m 0s (remain 362m 1s) Loss: 0.1052(0.2588) Grad: 2.0457  
Epoch: [1][15560/40998] Data 0.553 (0.553) Elapsed 221m 17s (remain 361m 44s) Loss: 0.0465(0.2587) Grad: 0.8731  
Epoch: [1][15580/40998] Data 0.553 (0.553) Elapsed 221m 34s (remain 361m 27s) Loss: 0.1282(0.2586) Grad: 2.6126  
Epoch: [1][15600/40998] Data 0.553 (0.553) Elapsed 221m 51s (remain 361m 10s) Loss: 0.1525(0.2585) Grad: 1.3204  
Epoch: [1][15620/40998] Data 0.554 (0.553) Elapsed 222m 8s (remain 360m 53s) Loss: 0.1702(0.2583) Grad: 2.2571  
Epoch: [1][15640/40998] Data 0.553 (0.553) Elapsed 222m 25s (remain 360m 36s) Loss: 0.0587(0.2581) Grad: 0.8714  
Epoch: [1][15660/40998] Data 0.553 (0.553) Elapsed 222m 43s (remain 360m 19s) Loss: 0.1006(0.2580) Grad: 1.6794  
Epoch: [1][15680/40998] Data 0.553 (0.553) Elapsed 223m 0s (remain 360m 2s) Loss: 0.0593(0.2579) Grad: 1.2584  
Epoch: [1][15700/40998] Data 0.553 (0.553) Elapsed 223m 17s (remain 359m 45s) Loss: 0.0812(0.2578) Grad: 1.9852  
Epoch: [1][15720/40998] Data 0.554 (0.553) Elapsed 223m 34s (remain 359m 28s) Loss: 0.2763(0.2577) Grad: 5.4404  
Epoch: [1][15740/40998] Data 0.553 (0.553) Elapsed 223m 51s (remain 359m 11s) Loss: 0.0817(0.2576) Grad: 1.4999  
Epoch: [1][15760/40998] Data 0.553 (0.553) Elapsed 224m 8s (remain 358m 53s) Loss: 0.2203(0.2575) Grad: 1.8733  
Epoch: [1][15780/40998] Data 0.551 (0.553) Elapsed 224m 25s (remain 358m 36s) Loss: 0.1179(0.2574) Grad: 2.0530  
Epoch: [1][15800/40998] Data 0.553 (0.553) Elapsed 224m 42s (remain 358m 19s) Loss: 0.4213(0.2573) Grad: 3.6555  
Epoch: [1][15820/40998] Data 0.553 (0.553) Elapsed 224m 59s (remain 358m 2s) Loss: 0.1180(0.2572) Grad: 1.4346  
Epoch: [1][15840/40998] Data 0.552 (0.553) Elapsed 225m 16s (remain 357m 45s) Loss: 0.2150(0.2571) Grad: 2.6488  
Epoch: [1][15860/40998] Data 0.553 (0.553) Elapsed 225m 33s (remain 357m 28s) Loss: 0.1406(0.2570) Grad: 2.6654  
Epoch: [1][15880/40998] Data 0.553 (0.553) Elapsed 225m 50s (remain 357m 11s) Loss: 0.1001(0.2569) Grad: 2.0635  
Epoch: [1][15900/40998] Data 0.553 (0.553) Elapsed 226m 7s (remain 356m 54s) Loss: 0.1162(0.2568) Grad: 1.2938  
Epoch: [1][15920/40998] Data 0.553 (0.553) Elapsed 226m 24s (remain 356m 37s) Loss: 0.2432(0.2567) Grad: 2.1032  
Epoch: [1][15940/40998] Data 0.553 (0.553) Elapsed 226m 41s (remain 356m 20s) Loss: 0.1921(0.2566) Grad: 2.4337  
Epoch: [1][15960/40998] Data 0.552 (0.553) Elapsed 226m 58s (remain 356m 3s) Loss: 0.1593(0.2565) Grad: 1.9173  
Epoch: [1][15980/40998] Data 0.553 (0.553) Elapsed 227m 16s (remain 355m 46s) Loss: 0.1437(0.2564) Grad: 2.6238  
Epoch: [1][16000/40998] Data 0.553 (0.553) Elapsed 227m 33s (remain 355m 29s) Loss: 0.0911(0.2563) Grad: 1.4334  
Epoch: [1][16020/40998] Data 0.552 (0.553) Elapsed 227m 50s (remain 355m 11s) Loss: 0.0852(0.2562) Grad: 1.9538  
Epoch: [1][16040/40998] Data 0.553 (0.553) Elapsed 228m 7s (remain 354m 54s) Loss: 0.0798(0.2561) Grad: 1.8507  
Epoch: [1][16060/40998] Data 0.551 (0.553) Elapsed 228m 24s (remain 354m 37s) Loss: 0.1155(0.2560) Grad: 3.8219  
Epoch: [1][16080/40998] Data 0.553 (0.553) Elapsed 228m 41s (remain 354m 20s) Loss: 0.0496(0.2559) Grad: 0.8304  
Epoch: [1][16100/40998] Data 0.553 (0.553) Elapsed 228m 58s (remain 354m 3s) Loss: 0.2334(0.2557) Grad: 2.7643  
Epoch: [1][16120/40998] Data 0.553 (0.553) Elapsed 229m 15s (remain 353m 46s) Loss: 0.1083(0.2556) Grad: 2.2074  
Epoch: [1][16140/40998] Data 0.553 (0.553) Elapsed 229m 32s (remain 353m 29s) Loss: 0.1463(0.2555) Grad: 1.8759  
Epoch: [1][16160/40998] Data 0.553 (0.553) Elapsed 229m 49s (remain 353m 12s) Loss: 0.1010(0.2554) Grad: 2.0354  
Epoch: [1][16180/40998] Data 0.553 (0.553) Elapsed 230m 6s (remain 352m 55s) Loss: 0.3762(0.2554) Grad: 6.5814  
Epoch: [1][16200/40998] Data 0.553 (0.553) Elapsed 230m 23s (remain 352m 38s) Loss: 0.1587(0.2552) Grad: 1.7839  
Epoch: [1][16220/40998] Data 0.553 (0.553) Elapsed 230m 40s (remain 352m 21s) Loss: 0.1720(0.2551) Grad: 2.5345  
Epoch: [1][16240/40998] Data 0.552 (0.553) Elapsed 230m 57s (remain 352m 4s) Loss: 0.1951(0.2550) Grad: 3.1240  
Epoch: [1][16260/40998] Data 0.553 (0.553) Elapsed 231m 14s (remain 351m 47s) Loss: 0.1698(0.2549) Grad: 3.8891  
Epoch: [1][16280/40998] Data 0.553 (0.553) Elapsed 231m 31s (remain 351m 30s) Loss: 0.3463(0.2548) Grad: 3.1023  
Epoch: [1][16300/40998] Data 0.553 (0.553) Elapsed 231m 49s (remain 351m 13s) Loss: 0.1600(0.2548) Grad: 2.4028  
Epoch: [1][16320/40998] Data 0.553 (0.553) Elapsed 232m 6s (remain 350m 55s) Loss: 0.4809(0.2547) Grad: 5.3612  
Epoch: [1][16340/40998] Data 0.553 (0.553) Elapsed 232m 23s (remain 350m 38s) Loss: 0.1293(0.2546) Grad: 1.3372  
Epoch: [1][16360/40998] Data 0.553 (0.553) Elapsed 232m 40s (remain 350m 21s) Loss: 0.2961(0.2545) Grad: 4.6070  
Epoch: [1][16380/40998] Data 0.552 (0.553) Elapsed 232m 57s (remain 350m 4s) Loss: 0.1762(0.2545) Grad: 1.4548  
Epoch: [1][16400/40998] Data 0.552 (0.553) Elapsed 233m 14s (remain 349m 47s) Loss: 0.1788(0.2544) Grad: 1.6819  
Epoch: [1][16420/40998] Data 0.553 (0.553) Elapsed 233m 31s (remain 349m 30s) Loss: 0.3669(0.2543) Grad: 7.9949  
Epoch: [1][16440/40998] Data 0.553 (0.553) Elapsed 233m 48s (remain 349m 13s) Loss: 0.1853(0.2542) Grad: 2.2960  
Epoch: [1][16460/40998] Data 0.553 (0.553) Elapsed 234m 5s (remain 348m 56s) Loss: 0.1317(0.2541) Grad: 2.3917  
Epoch: [1][16480/40998] Data 0.553 (0.553) Elapsed 234m 22s (remain 348m 39s) Loss: 0.2448(0.2540) Grad: 2.8505  
Epoch: [1][16500/40998] Data 0.553 (0.553) Elapsed 234m 39s (remain 348m 22s) Loss: 0.1250(0.2540) Grad: 2.7002  
Epoch: [1][16520/40998] Data 0.552 (0.553) Elapsed 234m 56s (remain 348m 5s) Loss: 0.0731(0.2539) Grad: 1.0064  
Epoch: [1][16540/40998] Data 0.553 (0.553) Elapsed 235m 13s (remain 347m 48s) Loss: 0.1491(0.2538) Grad: 2.1927  
Epoch: [1][16560/40998] Data 0.552 (0.553) Elapsed 235m 30s (remain 347m 31s) Loss: 0.5655(0.2537) Grad: 4.6435  
Epoch: [1][16580/40998] Data 0.553 (0.553) Elapsed 235m 47s (remain 347m 14s) Loss: 0.0820(0.2537) Grad: 1.8205  
Epoch: [1][16600/40998] Data 0.553 (0.553) Elapsed 236m 4s (remain 346m 56s) Loss: 0.2811(0.2536) Grad: 6.2083  
Epoch: [1][16620/40998] Data 0.553 (0.553) Elapsed 236m 22s (remain 346m 39s) Loss: 0.0928(0.2535) Grad: 1.6256  
Epoch: [1][16640/40998] Data 0.552 (0.553) Elapsed 236m 39s (remain 346m 22s) Loss: 0.0480(0.2534) Grad: 0.7441  
Epoch: [1][16660/40998] Data 0.553 (0.553) Elapsed 236m 56s (remain 346m 5s) Loss: 0.2888(0.2534) Grad: 3.9175  
Epoch: [1][16680/40998] Data 0.553 (0.553) Elapsed 237m 13s (remain 345m 48s) Loss: 0.3647(0.2533) Grad: 3.3249  
Epoch: [1][16700/40998] Data 0.553 (0.553) Elapsed 237m 30s (remain 345m 31s) Loss: 0.0516(0.2531) Grad: 0.8696  
Epoch: [1][16720/40998] Data 0.553 (0.553) Elapsed 237m 47s (remain 345m 14s) Loss: 0.1479(0.2531) Grad: 2.9752  
Epoch: [1][16740/40998] Data 0.553 (0.553) Elapsed 238m 4s (remain 344m 57s) Loss: 0.4763(0.2530) Grad: 5.1271  
Epoch: [1][16760/40998] Data 0.553 (0.553) Elapsed 238m 21s (remain 344m 40s) Loss: 0.1313(0.2529) Grad: 1.9964  
Epoch: [1][16780/40998] Data 0.553 (0.553) Elapsed 238m 38s (remain 344m 23s) Loss: 0.0684(0.2528) Grad: 1.3994  
Epoch: [1][16800/40998] Data 0.552 (0.553) Elapsed 238m 55s (remain 344m 6s) Loss: 0.1320(0.2527) Grad: 1.5602  
Epoch: [1][16820/40998] Data 0.553 (0.553) Elapsed 239m 12s (remain 343m 49s) Loss: 0.3163(0.2526) Grad: 3.2209  
Epoch: [1][16840/40998] Data 0.552 (0.553) Elapsed 239m 29s (remain 343m 32s) Loss: 0.0771(0.2525) Grad: 0.6902  
Epoch: [1][16860/40998] Data 0.553 (0.553) Elapsed 239m 46s (remain 343m 15s) Loss: 0.0243(0.2524) Grad: 0.5582  
Epoch: [1][16880/40998] Data 0.553 (0.553) Elapsed 240m 3s (remain 342m 57s) Loss: 0.2015(0.2523) Grad: 1.9992  
Epoch: [1][16900/40998] Data 0.553 (0.553) Elapsed 240m 20s (remain 342m 40s) Loss: 0.3221(0.2523) Grad: 3.6064  
Epoch: [1][16920/40998] Data 0.553 (0.553) Elapsed 240m 37s (remain 342m 23s) Loss: 0.3331(0.2522) Grad: 3.8411  
Epoch: [1][16940/40998] Data 0.553 (0.553) Elapsed 240m 54s (remain 342m 6s) Loss: 0.1680(0.2520) Grad: 1.7458  
Epoch: [1][16960/40998] Data 0.553 (0.553) Elapsed 241m 12s (remain 341m 49s) Loss: 0.1518(0.2519) Grad: 4.4711  
Epoch: [1][16980/40998] Data 0.553 (0.553) Elapsed 241m 29s (remain 341m 32s) Loss: 0.0633(0.2518) Grad: 1.2474  
Epoch: [1][17000/40998] Data 0.553 (0.553) Elapsed 241m 46s (remain 341m 15s) Loss: 0.1849(0.2517) Grad: 1.9928  
Epoch: [1][17020/40998] Data 0.553 (0.553) Elapsed 242m 3s (remain 340m 58s) Loss: 0.1266(0.2517) Grad: 2.7471  
Epoch: [1][17040/40998] Data 0.553 (0.553) Elapsed 242m 20s (remain 340m 41s) Loss: 0.0536(0.2516) Grad: 0.5454  
Epoch: [1][17060/40998] Data 0.553 (0.553) Elapsed 242m 37s (remain 340m 24s) Loss: 0.1152(0.2515) Grad: 1.4696  
Epoch: [1][17080/40998] Data 0.552 (0.553) Elapsed 242m 54s (remain 340m 7s) Loss: 0.1169(0.2514) Grad: 3.1552  
Epoch: [1][17100/40998] Data 0.553 (0.553) Elapsed 243m 11s (remain 339m 50s) Loss: 0.1567(0.2513) Grad: 1.6248  
Epoch: [1][17120/40998] Data 0.552 (0.553) Elapsed 243m 28s (remain 339m 33s) Loss: 0.2288(0.2512) Grad: 4.0448  
Epoch: [1][17140/40998] Data 0.552 (0.553) Elapsed 243m 45s (remain 339m 16s) Loss: 0.1395(0.2512) Grad: 2.3428  
Epoch: [1][17160/40998] Data 0.553 (0.553) Elapsed 244m 2s (remain 338m 58s) Loss: 0.4210(0.2511) Grad: 3.0639  
Epoch: [1][17180/40998] Data 0.553 (0.553) Elapsed 244m 19s (remain 338m 41s) Loss: 0.2881(0.2511) Grad: 3.2067  
Epoch: [1][17200/40998] Data 0.553 (0.553) Elapsed 244m 36s (remain 338m 24s) Loss: 0.0448(0.2510) Grad: 0.6094  
Epoch: [1][17220/40998] Data 0.553 (0.553) Elapsed 244m 53s (remain 338m 7s) Loss: 0.1482(0.2508) Grad: 2.8622  
Epoch: [1][17240/40998] Data 0.553 (0.553) Elapsed 245m 10s (remain 337m 50s) Loss: 0.1655(0.2508) Grad: 1.6844  
Epoch: [1][17260/40998] Data 0.552 (0.553) Elapsed 245m 27s (remain 337m 33s) Loss: 0.1472(0.2507) Grad: 2.3813  
Epoch: [1][17280/40998] Data 0.553 (0.553) Elapsed 245m 45s (remain 337m 16s) Loss: 0.3899(0.2506) Grad: 3.2701  
Epoch: [1][17300/40998] Data 0.553 (0.553) Elapsed 246m 2s (remain 336m 59s) Loss: 0.0811(0.2506) Grad: 0.5731  
Epoch: [1][17320/40998] Data 0.553 (0.553) Elapsed 246m 19s (remain 336m 42s) Loss: 0.1597(0.2504) Grad: 1.9984  
Epoch: [1][17340/40998] Data 0.553 (0.553) Elapsed 246m 36s (remain 336m 25s) Loss: 0.0406(0.2504) Grad: 0.9213  
Epoch: [1][17360/40998] Data 0.552 (0.553) Elapsed 246m 53s (remain 336m 8s) Loss: 0.2431(0.2503) Grad: 2.9174  
Epoch: [1][17380/40998] Data 0.553 (0.553) Elapsed 247m 10s (remain 335m 51s) Loss: 0.2570(0.2502) Grad: 4.4230  
Epoch: [1][17400/40998] Data 0.553 (0.553) Elapsed 247m 27s (remain 335m 34s) Loss: 0.0810(0.2501) Grad: 0.9942  
Epoch: [1][17420/40998] Data 0.553 (0.553) Elapsed 247m 44s (remain 335m 17s) Loss: 0.0589(0.2501) Grad: 1.0411  
Epoch: [1][17440/40998] Data 0.552 (0.553) Elapsed 248m 1s (remain 334m 59s) Loss: 0.0639(0.2500) Grad: 1.6222  
Epoch: [1][17460/40998] Data 0.553 (0.553) Elapsed 248m 18s (remain 334m 42s) Loss: 0.0288(0.2499) Grad: 0.3935  
Epoch: [1][17480/40998] Data 0.553 (0.553) Elapsed 248m 35s (remain 334m 25s) Loss: 0.1202(0.2498) Grad: 1.3275  
Epoch: [1][17500/40998] Data 0.553 (0.553) Elapsed 248m 52s (remain 334m 8s) Loss: 0.1213(0.2497) Grad: 2.3754  
Epoch: [1][17520/40998] Data 0.553 (0.553) Elapsed 249m 9s (remain 333m 51s) Loss: 0.0334(0.2496) Grad: 0.5979  
Epoch: [1][17540/40998] Data 0.552 (0.553) Elapsed 249m 26s (remain 333m 34s) Loss: 0.0470(0.2495) Grad: 1.3209  
Epoch: [1][17560/40998] Data 0.553 (0.553) Elapsed 249m 43s (remain 333m 17s) Loss: 0.0873(0.2494) Grad: 1.1118  
Epoch: [1][17580/40998] Data 0.552 (0.553) Elapsed 250m 0s (remain 333m 0s) Loss: 0.1014(0.2493) Grad: 2.4408  
Epoch: [1][17600/40998] Data 0.552 (0.553) Elapsed 250m 17s (remain 332m 43s) Loss: 0.1332(0.2493) Grad: 4.1381  
Epoch: [1][17620/40998] Data 0.553 (0.553) Elapsed 250m 35s (remain 332m 26s) Loss: 0.0795(0.2491) Grad: 1.2570  
Epoch: [1][17640/40998] Data 0.553 (0.553) Elapsed 250m 52s (remain 332m 9s) Loss: 0.1161(0.2490) Grad: 1.9118  
Epoch: [1][17660/40998] Data 0.553 (0.553) Elapsed 251m 9s (remain 331m 52s) Loss: 0.1031(0.2490) Grad: 1.3117  
Epoch: [1][17680/40998] Data 0.552 (0.553) Elapsed 251m 26s (remain 331m 35s) Loss: 0.1873(0.2489) Grad: 2.8792  
Epoch: [1][17700/40998] Data 0.553 (0.553) Elapsed 251m 43s (remain 331m 18s) Loss: 0.1755(0.2488) Grad: 3.4949  
Epoch: [1][17720/40998] Data 0.553 (0.553) Elapsed 252m 0s (remain 331m 0s) Loss: 0.2286(0.2487) Grad: 2.9451  
Epoch: [1][17740/40998] Data 0.553 (0.553) Elapsed 252m 17s (remain 330m 43s) Loss: 0.3279(0.2487) Grad: 2.4857  
Epoch: [1][17760/40998] Data 0.553 (0.553) Elapsed 252m 34s (remain 330m 26s) Loss: 0.2772(0.2486) Grad: 3.0324  
Epoch: [1][17780/40998] Data 0.552 (0.553) Elapsed 252m 51s (remain 330m 9s) Loss: 0.0954(0.2486) Grad: 1.4501  
Epoch: [1][17800/40998] Data 0.553 (0.553) Elapsed 253m 8s (remain 329m 52s) Loss: 0.1740(0.2485) Grad: 1.1246  
Epoch: [1][17820/40998] Data 0.552 (0.553) Elapsed 253m 25s (remain 329m 35s) Loss: 0.0844(0.2484) Grad: 1.5910  
Epoch: [1][17840/40998] Data 0.553 (0.553) Elapsed 253m 42s (remain 329m 18s) Loss: 0.2233(0.2482) Grad: 2.6648  
Epoch: [1][17860/40998] Data 0.552 (0.553) Elapsed 253m 59s (remain 329m 1s) Loss: 0.2347(0.2482) Grad: 4.7091  
Epoch: [1][17880/40998] Data 0.553 (0.553) Elapsed 254m 16s (remain 328m 44s) Loss: 0.1699(0.2482) Grad: 1.5545  
Epoch: [1][17900/40998] Data 0.552 (0.553) Elapsed 254m 33s (remain 328m 27s) Loss: 0.1253(0.2481) Grad: 1.7767  
Epoch: [1][17920/40998] Data 0.553 (0.553) Elapsed 254m 50s (remain 328m 10s) Loss: 0.0736(0.2480) Grad: 1.0490  
Epoch: [1][17940/40998] Data 0.553 (0.553) Elapsed 255m 8s (remain 327m 53s) Loss: 0.0950(0.2479) Grad: 0.9803  
Epoch: [1][17960/40998] Data 0.553 (0.553) Elapsed 255m 25s (remain 327m 36s) Loss: 0.0361(0.2478) Grad: 0.8263  
Epoch: [1][17980/40998] Data 0.553 (0.553) Elapsed 255m 42s (remain 327m 19s) Loss: 0.1219(0.2477) Grad: 1.7528  
Epoch: [1][18000/40998] Data 0.553 (0.553) Elapsed 255m 59s (remain 327m 1s) Loss: 0.4527(0.2476) Grad: 3.1047  
Epoch: [1][18020/40998] Data 0.553 (0.553) Elapsed 256m 16s (remain 326m 44s) Loss: 0.1700(0.2474) Grad: 2.7845  
Epoch: [1][18040/40998] Data 0.553 (0.553) Elapsed 256m 33s (remain 326m 27s) Loss: 0.1530(0.2473) Grad: 4.6904  
Epoch: [1][18060/40998] Data 0.552 (0.553) Elapsed 256m 50s (remain 326m 10s) Loss: 0.1233(0.2472) Grad: 1.7688  
Epoch: [1][18080/40998] Data 0.553 (0.553) Elapsed 257m 7s (remain 325m 53s) Loss: 0.1040(0.2471) Grad: 1.4221  
Epoch: [1][18100/40998] Data 0.552 (0.553) Elapsed 257m 24s (remain 325m 36s) Loss: 0.0451(0.2470) Grad: 0.9076  
Epoch: [1][18120/40998] Data 0.553 (0.553) Elapsed 257m 41s (remain 325m 19s) Loss: 0.0446(0.2469) Grad: 0.7721  
Epoch: [1][18140/40998] Data 0.553 (0.553) Elapsed 257m 58s (remain 325m 2s) Loss: 0.2053(0.2469) Grad: 2.2058  
Epoch: [1][18160/40998] Data 0.553 (0.553) Elapsed 258m 15s (remain 324m 45s) Loss: 0.1498(0.2467) Grad: 1.9591  
Epoch: [1][18180/40998] Data 0.553 (0.553) Elapsed 258m 32s (remain 324m 28s) Loss: 0.0827(0.2466) Grad: 2.1156  
Epoch: [1][18200/40998] Data 0.554 (0.553) Elapsed 258m 49s (remain 324m 11s) Loss: 0.2338(0.2466) Grad: 3.8750  
Epoch: [1][18220/40998] Data 0.553 (0.553) Elapsed 259m 6s (remain 323m 54s) Loss: 0.1064(0.2465) Grad: 1.6336  
Epoch: [1][18240/40998] Data 0.553 (0.553) Elapsed 259m 23s (remain 323m 37s) Loss: 0.0426(0.2464) Grad: 0.5867  
Epoch: [1][18260/40998] Data 0.552 (0.553) Elapsed 259m 41s (remain 323m 20s) Loss: 0.5421(0.2463) Grad: 7.1349  
Epoch: [1][18280/40998] Data 0.552 (0.553) Elapsed 259m 58s (remain 323m 3s) Loss: 0.1671(0.2462) Grad: 4.0544  
Epoch: [1][18300/40998] Data 0.553 (0.553) Elapsed 260m 15s (remain 322m 45s) Loss: 0.1459(0.2461) Grad: 1.7394  
Epoch: [1][18320/40998] Data 0.553 (0.553) Elapsed 260m 32s (remain 322m 28s) Loss: 0.0829(0.2460) Grad: 1.5254  
Epoch: [1][18340/40998] Data 0.553 (0.553) Elapsed 260m 49s (remain 322m 11s) Loss: 0.1042(0.2459) Grad: 2.5611  
Epoch: [1][18360/40998] Data 0.553 (0.553) Elapsed 261m 6s (remain 321m 54s) Loss: 0.0661(0.2459) Grad: 1.0508  
Epoch: [1][18380/40998] Data 0.552 (0.553) Elapsed 261m 23s (remain 321m 37s) Loss: 0.0505(0.2458) Grad: 0.7083  
Epoch: [1][18400/40998] Data 0.552 (0.553) Elapsed 261m 40s (remain 321m 20s) Loss: 0.3217(0.2457) Grad: 2.5998  
Epoch: [1][18420/40998] Data 0.553 (0.553) Elapsed 261m 57s (remain 321m 3s) Loss: 0.0896(0.2456) Grad: 1.1268  
Epoch: [1][18440/40998] Data 0.553 (0.553) Elapsed 262m 14s (remain 320m 46s) Loss: 0.1998(0.2455) Grad: 2.3304  
Epoch: [1][18460/40998] Data 0.553 (0.553) Elapsed 262m 31s (remain 320m 29s) Loss: 0.1079(0.2454) Grad: 2.0039  
Epoch: [1][18480/40998] Data 0.553 (0.553) Elapsed 262m 48s (remain 320m 12s) Loss: 0.1348(0.2453) Grad: 2.3925  
Epoch: [1][18500/40998] Data 0.552 (0.553) Elapsed 263m 5s (remain 319m 55s) Loss: 0.1331(0.2452) Grad: 2.6048  
Epoch: [1][18520/40998] Data 0.552 (0.553) Elapsed 263m 22s (remain 319m 38s) Loss: 0.1928(0.2451) Grad: 1.7789  
Epoch: [1][18540/40998] Data 0.553 (0.553) Elapsed 263m 39s (remain 319m 21s) Loss: 0.1155(0.2450) Grad: 2.1506  
Epoch: [1][18560/40998] Data 0.553 (0.553) Elapsed 263m 56s (remain 319m 4s) Loss: 0.0387(0.2450) Grad: 0.5673  
Epoch: [1][18580/40998] Data 0.553 (0.553) Elapsed 264m 13s (remain 318m 46s) Loss: 0.2733(0.2449) Grad: 2.6358  
Epoch: [1][18600/40998] Data 0.553 (0.553) Elapsed 264m 31s (remain 318m 29s) Loss: 0.2935(0.2448) Grad: 3.2840  
Epoch: [1][18620/40998] Data 0.553 (0.553) Elapsed 264m 48s (remain 318m 12s) Loss: 0.2186(0.2448) Grad: 4.8716  
Epoch: [1][18640/40998] Data 0.553 (0.553) Elapsed 265m 5s (remain 317m 55s) Loss: 0.1864(0.2447) Grad: 2.4338  
Epoch: [1][18660/40998] Data 0.553 (0.553) Elapsed 265m 22s (remain 317m 38s) Loss: 0.1841(0.2446) Grad: 2.1224  
Epoch: [1][18680/40998] Data 0.552 (0.553) Elapsed 265m 39s (remain 317m 21s) Loss: 0.2263(0.2445) Grad: 3.0600  
Epoch: [1][18700/40998] Data 0.553 (0.553) Elapsed 265m 56s (remain 317m 4s) Loss: 0.3420(0.2444) Grad: 2.1759  
Epoch: [1][18720/40998] Data 0.552 (0.553) Elapsed 266m 13s (remain 316m 47s) Loss: 0.2274(0.2443) Grad: 4.0982  
Epoch: [1][18740/40998] Data 0.553 (0.553) Elapsed 266m 30s (remain 316m 30s) Loss: 0.3606(0.2443) Grad: 4.0770  
Epoch: [1][18760/40998] Data 0.553 (0.553) Elapsed 266m 47s (remain 316m 13s) Loss: 0.0340(0.2442) Grad: 0.5353  
Epoch: [1][18780/40998] Data 0.553 (0.553) Elapsed 267m 4s (remain 315m 56s) Loss: 0.1919(0.2441) Grad: 3.8375  
Epoch: [1][18800/40998] Data 0.553 (0.553) Elapsed 267m 21s (remain 315m 39s) Loss: 0.1648(0.2440) Grad: 2.5278  
Epoch: [1][18820/40998] Data 0.552 (0.553) Elapsed 267m 38s (remain 315m 22s) Loss: 0.0451(0.2439) Grad: 0.8171  
Epoch: [1][18840/40998] Data 0.552 (0.553) Elapsed 267m 55s (remain 315m 5s) Loss: 0.2507(0.2438) Grad: 2.6086  
Epoch: [1][18860/40998] Data 0.553 (0.553) Elapsed 268m 12s (remain 314m 47s) Loss: 0.1341(0.2437) Grad: 1.7936  
Epoch: [1][18880/40998] Data 0.553 (0.553) Elapsed 268m 29s (remain 314m 30s) Loss: 0.2315(0.2437) Grad: 4.1346  
Epoch: [1][18900/40998] Data 0.552 (0.553) Elapsed 268m 46s (remain 314m 13s) Loss: 0.3452(0.2435) Grad: 2.8581  
Epoch: [1][18920/40998] Data 0.553 (0.553) Elapsed 269m 3s (remain 313m 56s) Loss: 0.1273(0.2435) Grad: 1.6399  
Epoch: [1][18940/40998] Data 0.553 (0.553) Elapsed 269m 21s (remain 313m 39s) Loss: 0.1423(0.2434) Grad: 1.7945  
Epoch: [1][18960/40998] Data 0.552 (0.553) Elapsed 269m 38s (remain 313m 22s) Loss: 0.0461(0.2433) Grad: 0.6477  
Epoch: [1][18980/40998] Data 0.553 (0.553) Elapsed 269m 55s (remain 313m 5s) Loss: 0.0917(0.2432) Grad: 1.9324  
Epoch: [1][19000/40998] Data 0.553 (0.553) Elapsed 270m 12s (remain 312m 48s) Loss: 0.0565(0.2432) Grad: 1.1468  
Epoch: [1][19020/40998] Data 0.553 (0.553) Elapsed 270m 29s (remain 312m 31s) Loss: 0.2122(0.2431) Grad: 5.7390  
Epoch: [1][19040/40998] Data 0.553 (0.553) Elapsed 270m 46s (remain 312m 14s) Loss: 0.3473(0.2430) Grad: 3.8021  
Epoch: [1][19060/40998] Data 0.553 (0.553) Elapsed 271m 3s (remain 311m 57s) Loss: 0.2280(0.2429) Grad: 3.0202  
Epoch: [1][19080/40998] Data 0.553 (0.553) Elapsed 271m 20s (remain 311m 40s) Loss: 0.1162(0.2429) Grad: 1.2837  
Epoch: [1][19100/40998] Data 0.553 (0.553) Elapsed 271m 37s (remain 311m 23s) Loss: 0.0909(0.2428) Grad: 1.5304  
Epoch: [1][19120/40998] Data 0.552 (0.553) Elapsed 271m 54s (remain 311m 6s) Loss: 0.1120(0.2427) Grad: 1.6885  
Epoch: [1][19140/40998] Data 0.553 (0.553) Elapsed 272m 11s (remain 310m 49s) Loss: 0.1221(0.2427) Grad: 2.0557  
Epoch: [1][19160/40998] Data 0.553 (0.553) Elapsed 272m 28s (remain 310m 31s) Loss: 0.2757(0.2425) Grad: 2.7100  
Epoch: [1][19180/40998] Data 0.553 (0.553) Elapsed 272m 45s (remain 310m 14s) Loss: 0.0795(0.2425) Grad: 1.6499  
Epoch: [1][19200/40998] Data 0.552 (0.553) Elapsed 273m 2s (remain 309m 57s) Loss: 0.2080(0.2424) Grad: 1.6410  
Epoch: [1][19220/40998] Data 0.553 (0.553) Elapsed 273m 19s (remain 309m 40s) Loss: 0.0771(0.2423) Grad: 0.7612  
Epoch: [1][19240/40998] Data 0.553 (0.553) Elapsed 273m 36s (remain 309m 23s) Loss: 0.4706(0.2423) Grad: 4.3592  
Epoch: [1][19260/40998] Data 0.553 (0.553) Elapsed 273m 54s (remain 309m 6s) Loss: 0.3268(0.2422) Grad: 2.8371  
Epoch: [1][19280/40998] Data 0.553 (0.553) Elapsed 274m 11s (remain 308m 49s) Loss: 0.3448(0.2421) Grad: 4.6229  
Epoch: [1][19300/40998] Data 0.552 (0.553) Elapsed 274m 28s (remain 308m 32s) Loss: 0.0629(0.2420) Grad: 0.8713  
Epoch: [1][19320/40998] Data 0.553 (0.553) Elapsed 274m 45s (remain 308m 15s) Loss: 0.2597(0.2420) Grad: 3.0607  
Epoch: [1][19340/40998] Data 0.553 (0.553) Elapsed 275m 2s (remain 307m 58s) Loss: 0.1673(0.2419) Grad: 2.5576  
Epoch: [1][19360/40998] Data 0.552 (0.553) Elapsed 275m 19s (remain 307m 41s) Loss: 0.2655(0.2418) Grad: 2.3812  
Epoch: [1][19380/40998] Data 0.553 (0.553) Elapsed 275m 36s (remain 307m 24s) Loss: 0.0996(0.2417) Grad: 2.2970  
Epoch: [1][19400/40998] Data 0.553 (0.553) Elapsed 275m 53s (remain 307m 7s) Loss: 0.0725(0.2417) Grad: 0.9698  
Epoch: [1][19420/40998] Data 0.553 (0.553) Elapsed 276m 10s (remain 306m 50s) Loss: 0.0410(0.2416) Grad: 0.6932  
Epoch: [1][19440/40998] Data 0.553 (0.553) Elapsed 276m 27s (remain 306m 33s) Loss: 0.1619(0.2415) Grad: 1.6838  
Epoch: [1][19460/40998] Data 0.553 (0.553) Elapsed 276m 44s (remain 306m 15s) Loss: 0.3124(0.2415) Grad: 3.8552  
Epoch: [1][19480/40998] Data 0.553 (0.553) Elapsed 277m 1s (remain 305m 58s) Loss: 0.1986(0.2414) Grad: 2.7909  
Epoch: [1][19500/40998] Data 0.553 (0.553) Elapsed 277m 18s (remain 305m 41s) Loss: 0.2110(0.2414) Grad: 2.3025  
Epoch: [1][19520/40998] Data 0.553 (0.553) Elapsed 277m 35s (remain 305m 24s) Loss: 0.0698(0.2413) Grad: 0.7284  
Epoch: [1][19540/40998] Data 0.552 (0.553) Elapsed 277m 52s (remain 305m 7s) Loss: 0.1746(0.2412) Grad: 2.9130  
Epoch: [1][19560/40998] Data 0.553 (0.553) Elapsed 278m 9s (remain 304m 50s) Loss: 0.0882(0.2412) Grad: 1.0842  
Epoch: [1][19580/40998] Data 0.553 (0.553) Elapsed 278m 27s (remain 304m 33s) Loss: 0.1096(0.2411) Grad: 1.2281  
Epoch: [1][19600/40998] Data 0.553 (0.553) Elapsed 278m 44s (remain 304m 16s) Loss: 0.0565(0.2410) Grad: 1.2948  
Epoch: [1][19620/40998] Data 0.552 (0.553) Elapsed 279m 1s (remain 303m 59s) Loss: 0.1108(0.2409) Grad: 1.2629  
Epoch: [1][19640/40998] Data 0.552 (0.553) Elapsed 279m 18s (remain 303m 42s) Loss: 0.1209(0.2408) Grad: 1.5887  
Epoch: [1][19660/40998] Data 0.553 (0.553) Elapsed 279m 35s (remain 303m 25s) Loss: 0.1309(0.2408) Grad: 2.9003  
Epoch: [1][19680/40998] Data 0.552 (0.553) Elapsed 279m 52s (remain 303m 8s) Loss: 0.1696(0.2407) Grad: 1.5461  
Epoch: [1][19700/40998] Data 0.552 (0.553) Elapsed 280m 9s (remain 302m 51s) Loss: 0.2936(0.2406) Grad: 4.2503  
Epoch: [1][19720/40998] Data 0.553 (0.553) Elapsed 280m 26s (remain 302m 34s) Loss: 0.4948(0.2405) Grad: 3.9454  
Epoch: [1][19740/40998] Data 0.553 (0.553) Elapsed 280m 43s (remain 302m 17s) Loss: 0.1811(0.2404) Grad: 2.4617  
Epoch: [1][19760/40998] Data 0.553 (0.553) Elapsed 281m 0s (remain 301m 59s) Loss: 0.0719(0.2404) Grad: 0.8475  
Epoch: [1][19780/40998] Data 0.552 (0.553) Elapsed 281m 17s (remain 301m 42s) Loss: 0.0738(0.2402) Grad: 0.7520  
Epoch: [1][19800/40998] Data 0.553 (0.553) Elapsed 281m 34s (remain 301m 25s) Loss: 0.0722(0.2402) Grad: 1.4441  
Epoch: [1][19820/40998] Data 0.553 (0.553) Elapsed 281m 51s (remain 301m 8s) Loss: 0.0865(0.2401) Grad: 1.5266  
Epoch: [1][19840/40998] Data 0.553 (0.553) Elapsed 282m 8s (remain 300m 51s) Loss: 0.1173(0.2400) Grad: 1.5981  
Epoch: [1][19860/40998] Data 0.553 (0.553) Elapsed 282m 25s (remain 300m 34s) Loss: 0.0407(0.2399) Grad: 0.8835  
Epoch: [1][19880/40998] Data 0.553 (0.553) Elapsed 282m 42s (remain 300m 17s) Loss: 0.2849(0.2398) Grad: 2.5711  
Epoch: [1][19900/40998] Data 0.553 (0.553) Elapsed 283m 0s (remain 300m 0s) Loss: 0.7922(0.2398) Grad: 15.4472  
Epoch: [1][19920/40998] Data 0.553 (0.553) Elapsed 283m 17s (remain 299m 43s) Loss: 0.1902(0.2398) Grad: 1.2798  
Epoch: [1][19940/40998] Data 0.553 (0.553) Elapsed 283m 34s (remain 299m 26s) Loss: 0.2658(0.2397) Grad: 2.5439  
Epoch: [1][19960/40998] Data 0.553 (0.553) Elapsed 283m 51s (remain 299m 9s) Loss: 0.1158(0.2396) Grad: 2.5047  
Epoch: [1][19980/40998] Data 0.552 (0.553) Elapsed 284m 8s (remain 298m 52s) Loss: 0.1399(0.2396) Grad: 1.3208  
In [ ]: